Skip to main content

ai_agent/tools/
deferred_tools.rs

1// Source: ~/claudecode/openclaudecode/src/tools/ToolSearchTool/prompt.ts
2// and ~/claudecode/openclaudecode/src/utils/toolSearch.ts
3use crate::tools::config_tools::TOOL_SEARCH_TOOL_NAME;
4use crate::types::ToolDefinition;
5use crate::utils::env_utils;
6use std::collections::HashSet;
7
8/// Check if a tool should be deferred (requires ToolSearch to load).
9/// A tool is deferred if:
10/// - It's an MCP tool (always deferred)
11/// - It has should_defer: true
12///
13/// A tool is NEVER deferred if:
14/// - It has always_load: true
15/// - It's the ToolSearchTool itself
16/// - It's one of the special exceptions (Brief, SendUserFile, Agent when fork enabled)
17pub fn is_deferred_tool(tool: &ToolDefinition) -> bool {
18    // Explicit opt-out via always_load — tool appears in initial prompt
19    if tool.always_load == Some(true) {
20        return false;
21    }
22
23    // MCP tools are always deferred
24    if tool.is_mcp == Some(true) {
25        return true;
26    }
27
28    // Never defer ToolSearch itself — the model needs it to load everything else
29    if tool.name == TOOL_SEARCH_TOOL_NAME {
30        return false;
31    }
32
33    // Fork-first experiment: Agent must be available turn 1
34    // (Simplified: if fork_subagent feature would be on, don't defer Agent)
35    // For now, we don't defer Agent by default in the Rust SDK
36    if tool.name == "Agent" {
37        return false;
38    }
39
40    return tool.should_defer == Some(true);
41}
42
43/// Format one deferred-tool line for the <available-deferred-tools> message
44pub fn format_deferred_tool_line(tool: &ToolDefinition) -> String {
45    tool.name.clone()
46}
47
48/// Get the list of deferred tool names from a tool list
49pub fn get_deferred_tool_names(tools: &[ToolDefinition]) -> Vec<String> {
50    tools
51        .iter()
52        .filter(|t| is_deferred_tool(t))
53        .map(|t| t.name.clone())
54        .collect()
55}
56
57/// Build the <available-deferred-tools> block content
58pub fn build_available_deferred_tools_block(tools: &[ToolDefinition]) -> String {
59    let deferred_names: Vec<String> = get_deferred_tool_names(tools);
60    if deferred_names.is_empty() {
61        return String::new();
62    }
63    format!(
64        "<available-deferred-tools>\n{}\n</available-deferred-tools>",
65        deferred_names.join("\n")
66    )
67}
68
69/// Extract discovered tool names from message history.
70/// Scans for tool_reference blocks in tool_result content.
71/// Returns the set of tool names that have been discovered via tool_reference blocks.
72pub fn extract_discovered_tool_names(messages: &[serde_json::Value]) -> HashSet<String> {
73    let mut discovered = HashSet::new();
74
75    for msg in messages {
76        // Only user messages contain tool_result blocks
77        if msg.get("role").and_then(|v| v.as_str()) != Some("user") {
78            continue;
79        }
80
81        let content = match msg.get("content") {
82            Some(c) => c,
83            None => continue,
84        };
85
86        // Content can be a string (JSON-encoded) or an array of content blocks
87        // First, try to parse it as JSON if it's a string
88        let content_value = if let Some(content_str) = content.as_str() {
89            // Try to parse the string as JSON
90            match serde_json::from_str::<serde_json::Value>(content_str) {
91                Ok(parsed) => parsed,
92                Err(_) => continue, // Not valid JSON, skip
93            }
94        } else {
95            content.clone()
96        };
97
98        // Now look for tool_reference blocks
99        if let Some(content_array) = content_value.as_array() {
100            for block in content_array {
101                // tool_reference blocks appear inside tool_result content
102                if let Some(block_array) = block.get("content").and_then(|v| v.as_array()) {
103                    for item in block_array {
104                        if item.get("type").and_then(|v| v.as_str()) == Some("tool_reference") {
105                            if let Some(tool_name) = item.get("tool_name").and_then(|v| v.as_str())
106                            {
107                                discovered.insert(tool_name.to_string());
108                            }
109                        }
110                    }
111                }
112            }
113        }
114    }
115
116    discovered
117}
118
119/// Get tool search mode: "tst", "tst-auto", or "standard"
120pub fn get_tool_search_mode() -> &'static str {
121    // Check kill switch
122    if env_utils::is_env_truthy(
123        std::env::var("CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS")
124            .ok()
125            .as_deref(),
126    ) {
127        return "standard";
128    }
129
130    let value = std::env::var("ENABLE_TOOL_SEARCH").ok();
131
132    // Handle auto:N syntax
133    if let Some(ref v) = value {
134        if let Some(percent) = parse_auto_percentage(v) {
135            if percent == 0 {
136                return "tst";
137            }
138            if percent == 100 {
139                return "standard";
140            }
141            return "tst-auto";
142        }
143    }
144
145    if env_utils::is_env_truthy(value.as_deref()) {
146        return "tst";
147    }
148    if env_utils::is_env_defined_falsy(value.as_deref()) {
149        return "standard";
150    }
151    // Default: always defer MCP and shouldDefer tools
152    "tst"
153}
154
/// Parse the `N` out of an `ENABLE_TOOL_SEARCH=auto:N` value.
///
/// Returns `None` when `value` does not start with `auto:` or when `N` is not
/// an integer. Whitespace around `N` is ignored, and the result is clamped to
/// `0..=100`. Integers too large for `i32` clamp to the matching bound instead
/// of being rejected.
fn parse_auto_percentage(value: &str) -> Option<i32> {
    let percent_str = value.strip_prefix("auto:")?.trim();
    match percent_str.parse::<i32>() {
        Ok(percent) => Some(percent.clamp(0, 100)),
        Err(err) => match err.kind() {
            std::num::IntErrorKind::PosOverflow => Some(100),
            std::num::IntErrorKind::NegOverflow => Some(0),
            _ => None,
        },
    }
}
163
164/// Check if tool search might be enabled (optimistic check).
165/// Returns true if tool search could potentially be enabled.
166pub fn is_tool_search_enabled_optimistic() -> bool {
167    let mode = get_tool_search_mode();
168    if mode == "standard" {
169        return false;
170    }
171    // Check if using a proxy that might not support tool_reference
172    if std::env::var("ENABLE_TOOL_SEARCH").is_err() {
173        if let Ok(base_url) = std::env::var("ANTHROPIC_BASE_URL") {
174            let first_party_hosts = ["api.anthropic.com", "api.anthropic.ai"];
175            if !first_party_hosts.iter().any(|h| base_url.contains(h)) {
176                return false;
177            }
178        }
179    }
180    true
181}
182
/// Parse a ToolSearchTool query into a [`ToolSearchQuery`].
///
/// - `"select:Read,Edit,Grep"` → `Select(["Read", "Edit", "Grep"])`.
///   The `select:` prefix is matched ASCII case-insensitively. Names are
///   trimmed, and empty entries are dropped.
/// - `"notebook jupyter"` → `Keyword("notebook jupyter")`, the original query text.
/// - `"+slack send"` → `KeywordWithRequired { required: ["slack"], optional: ["send"] }`.
///   Only the first `+` is stripped. A bare `+` is not a required marker and
///   stays an optional term.
pub fn parse_tool_search_query(query: &str) -> ToolSearchQuery {
    const SELECT_PREFIX: &str = "select:";

    // `get` returns None instead of panicking if the prefix length is not a
    // char boundary.
    if let Some(prefix) = query.get(..SELECT_PREFIX.len()) {
        if prefix.eq_ignore_ascii_case(SELECT_PREFIX) {
            let tools = query[SELECT_PREFIX.len()..]
                .split(',')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .map(String::from)
                .collect();
            return ToolSearchQuery::Select(tools);
        }
    }

    let mut required = Vec::new();
    let mut optional = Vec::new();
    for term in query.split_whitespace() {
        match term.strip_prefix('+') {
            Some(rest) if !rest.is_empty() => required.push(rest.to_string()),
            _ => optional.push(term.to_string()),
        }
    }

    if required.is_empty() {
        ToolSearchQuery::Keyword(query.to_string())
    } else {
        ToolSearchQuery::KeywordWithRequired { required, optional }
    }
}

/// A parsed ToolSearchTool query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolSearchQuery {
    /// Direct selection: "select:Read,Edit"
    Select(Vec<String>),
    /// Simple keyword search; holds the original query text.
    Keyword(String),
    /// Keyword search with required terms
    KeywordWithRequired {
        /// Terms written as `+term`, with the leading `+` removed.
        required: Vec<String>,
        /// All remaining whitespace-separated terms.
        optional: Vec<String>,
    },
}
232
/// Split a tool name into lowercase searchable parts.
///
/// - MCP tools (`mcp__server__tool_name`): the `mcp__` prefix is dropped and the
///   rest is split on `__` and `_`. `"mcp__github__get_pr"` →
///   parts `["github", "get", "pr"]`, full `"github get pr"`.
/// - Other tools: split at every ASCII lowercase→uppercase boundary and at
///   `_`. `"WebSearch"` → parts `["web", "search"]`, full `"web search"`.
pub fn parse_tool_name(name: &str) -> ToolNameParts {
    if let Some(without_prefix) = name.strip_prefix("mcp__") {
        let parts = without_prefix
            .split("__")
            .flat_map(|p| p.split('_'))
            .filter(|s| !s.is_empty())
            .map(str::to_lowercase)
            .collect();
        return ToolNameParts {
            parts,
            full: without_prefix
                .replace("__", " ")
                .replace('_', " ")
                .to_lowercase(),
            is_mcp: true,
        };
    }

    // Regular tool: split CamelCase and snake_case into words.
    let spaced = insert_camel_case_spaces(name).replace('_', " ");
    let parts: Vec<String> = spaced
        .split_whitespace()
        .map(str::to_lowercase)
        .collect();
    let full = parts.join(" ");

    ToolNameParts {
        parts,
        full,
        is_mcp: false,
    }
}

/// Insert a space before every ASCII uppercase letter that directly follows an
/// ASCII lowercase letter: `"WebSearch"` → `"Web Search"`, `"HTTPServer"` unchanged.
fn insert_camel_case_spaces(name: &str) -> String {
    let mut out = String::with_capacity(name.len() + 4);
    let mut prev_lower = false;
    for ch in name.chars() {
        if prev_lower && ch.is_ascii_uppercase() {
            out.push(' ');
        }
        out.push(ch);
        prev_lower = ch.is_ascii_lowercase();
    }
    out
}

/// Searchable pieces of a tool name produced by [`parse_tool_name`].
#[derive(Debug, Clone)]
pub struct ToolNameParts {
    /// Lowercase name words, in order.
    pub parts: Vec<String>,
    /// Lowercase words joined by single spaces (MCP: `_` and `__` become spaces).
    pub full: String,
    /// Whether the name used the `mcp__` prefix.
    pub is_mcp: bool,
}
277
278/// Search deferred tools by keyword query
279pub fn search_tools_with_keywords(
280    query: &str,
281    deferred_tools: &[&ToolDefinition],
282    max_results: usize,
283) -> Vec<String> {
284    let query_lower = query.to_lowercase().trim().to_string();
285
286    // Fast path: exact match on tool name
287    if let Some(exact) = deferred_tools
288        .iter()
289        .find(|t| t.name.to_lowercase() == query_lower)
290    {
291        return vec![exact.name.clone()];
292    }
293
294    // MCP prefix match
295    if query_lower.starts_with("mcp__") && query_lower.len() > 5 {
296        let matches: Vec<String> = deferred_tools
297            .iter()
298            .filter(|t| t.name.to_lowercase().starts_with(&query_lower))
299            .take(max_results)
300            .map(|t| t.name.clone())
301            .collect();
302        if !matches.is_empty() {
303            return matches;
304        }
305    }
306
307    let query_terms: Vec<&str> = query_lower
308        .split_whitespace()
309        .filter(|t| !t.is_empty())
310        .collect();
311
312    // Partition into required (+prefixed) and optional terms
313    let mut required_terms = Vec::new();
314    let mut optional_terms = Vec::new();
315
316    for term in &query_terms {
317        if term.starts_with('+') && term.len() > 1 {
318            required_terms.push(&term[1..]);
319        } else {
320            optional_terms.push(term);
321        }
322    }
323
324    let all_terms: Vec<&str> = if !required_terms.is_empty() {
325        let mut combined: Vec<&str> = required_terms.clone();
326        combined.extend(optional_terms.iter().map(|x| **x));
327        combined
328    } else {
329        optional_terms.iter().map(|x| **x).collect()
330    };
331
332    // Score each tool
333    let mut scored: Vec<(String, i32)> = deferred_tools
334        .iter()
335        .filter_map(|tool| {
336            let parsed = parse_tool_name(&tool.name);
337            let desc_lower = tool.description.to_lowercase();
338            let hint_lower = tool
339                .search_hint
340                .as_ref()
341                .map(|h| h.to_lowercase())
342                .unwrap_or_default();
343
344            // Pre-filter: if required terms, must match at least one
345            if !required_terms.is_empty() {
346                let matches_all = required_terms.iter().all(|&term| {
347                    parsed.parts.iter().any(|p| p == term || p.contains(term))
348                        || desc_lower.contains(term)
349                        || hint_lower.contains(term)
350                });
351                if !matches_all {
352                    return None;
353                }
354            }
355
356            let mut score = 0;
357            for &term in &all_terms {
358                // Exact part match
359                if parsed.parts.iter().any(|p| p == term) {
360                    score += if parsed.is_mcp { 12 } else { 10 };
361                } else if parsed.parts.iter().any(|p| p.contains(term)) {
362                    score += if parsed.is_mcp { 6 } else { 5 };
363                }
364
365                // Full name fallback
366                if score == 0 && parsed.full.contains(term) {
367                    score += 3;
368                }
369
370                // Search hint match
371                if !hint_lower.is_empty() && hint_lower.contains(term) {
372                    score += 4;
373                }
374
375                // Description match
376                if desc_lower.contains(term) {
377                    score += 2;
378                }
379            }
380
381            if score > 0 {
382                Some((tool.name.clone(), score))
383            } else {
384                None
385            }
386        })
387        .collect();
388
389    scored.sort_by(|a, b| b.1.cmp(&a.1));
390    scored
391        .into_iter()
392        .take(max_results)
393        .map(|(name, _)| name)
394        .collect()
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ToolDefinition` with an empty description and default schema,
    /// setting only the deferral-related flags under test.
    fn make_tool(
        name: &str,
        should_defer: Option<bool>,
        is_mcp: Option<bool>,
        always_load: Option<bool>,
    ) -> ToolDefinition {
        let mut t = ToolDefinition::new(name, "", crate::types::ToolInputSchema::default());
        t.should_defer = should_defer;
        t.is_mcp = is_mcp;
        t.always_load = always_load;
        t
    }

    #[test]
    fn test_is_deferred_tool_mcp() {
        let tool = make_tool("mcp__github__pr", None, Some(true), None);
        assert!(is_deferred_tool(&tool));
    }

    #[test]
    fn test_is_deferred_tool_mcp_always_load_opts_out() {
        let tool = make_tool("mcp__github__pr", None, Some(true), Some(true));
        assert!(!is_deferred_tool(&tool));
    }

    #[test]
    fn test_is_deferred_tool_should_defer() {
        let tool = make_tool("WebSearch", Some(true), None, None);
        assert!(is_deferred_tool(&tool));
    }

    #[test]
    fn test_is_deferred_tool_default_not_deferred() {
        assert!(!is_deferred_tool(&make_tool("Bash", None, None, None)));
        assert!(!is_deferred_tool(&make_tool("Bash", Some(false), None, None)));
    }

    #[test]
    fn test_is_deferred_tool_always_load() {
        let tool = make_tool("Brief", Some(true), None, Some(true));
        assert!(!is_deferred_tool(&tool));
    }

    #[test]
    fn test_is_deferred_tool_tool_search() {
        let tool = make_tool(TOOL_SEARCH_TOOL_NAME, Some(true), None, None);
        // ToolSearch should never be deferred
        assert!(!is_deferred_tool(&tool));
    }

    #[test]
    fn test_is_deferred_tool_agent() {
        let tool = make_tool("Agent", Some(true), None, None);
        assert!(!is_deferred_tool(&tool));
    }

    #[test]
    fn test_deferred_tool_names() {
        let tool1 = make_tool("Bash", None, None, None);
        let tool2 = make_tool("WebSearch", Some(true), None, None);
        let tool3 = make_tool("mcp__slack__send", None, Some(true), None);
        let tool4 = make_tool("Read", None, None, None);
        let tools = vec![tool1, tool2, tool3, tool4];
        let deferred = get_deferred_tool_names(&tools);
        assert_eq!(deferred, vec!["WebSearch", "mcp__slack__send"]);
    }

    #[test]
    fn test_build_available_deferred_tools_block() {
        let none = vec![make_tool("Bash", None, None, None)];
        assert_eq!(build_available_deferred_tools_block(&none), "");

        let some = vec![
            make_tool("WebSearch", Some(true), None, None),
            make_tool("mcp__slack__send", None, Some(true), None),
        ];
        assert_eq!(
            build_available_deferred_tools_block(&some),
            "<available-deferred-tools>\nWebSearch\nmcp__slack__send\n</available-deferred-tools>"
        );
    }

    #[test]
    fn test_extract_discovered_tool_names() {
        let encoded = r#"[{"content":[{"type":"tool_reference","tool_name":"Grep"}]}]"#;
        let messages = vec![
            serde_json::json!({
                "role": "assistant",
                "content": [{"content": [{"type": "tool_reference", "tool_name": "Ignored"}]}]
            }),
            serde_json::json!({
                "role": "user",
                "content": [{
                    "type": "tool_result",
                    "content": [
                        {"type": "tool_reference", "tool_name": "WebSearch"},
                        {"type": "text", "text": "not a reference"}
                    ]
                }]
            }),
            serde_json::json!({"role": "user", "content": "plain text, not json"}),
            serde_json::json!({"role": "user", "content": encoded}),
        ];
        let found = extract_discovered_tool_names(&messages);
        assert_eq!(found.len(), 2);
        assert!(found.contains("WebSearch"));
        assert!(found.contains("Grep"));
    }

    #[test]
    fn test_parse_tool_name_regular() {
        let parts = parse_tool_name("Read");
        assert_eq!(parts.parts, vec!["read"]);
        assert_eq!(parts.full, "read");
        assert!(!parts.is_mcp);
    }

    #[test]
    fn test_parse_tool_name_mcp() {
        let parts = parse_tool_name("mcp__github__get_pr");
        assert_eq!(parts.parts, vec!["github", "get", "pr"]);
        assert!(parts.is_mcp);
    }

    #[test]
    fn test_parse_query_select() {
        let q = parse_tool_search_query("select:Read,Edit,Grep");
        match q {
            ToolSearchQuery::Select(tools) => {
                assert_eq!(tools, vec!["Read", "Edit", "Grep"]);
            }
            _ => panic!("Expected Select query"),
        }
    }

    #[test]
    fn test_parse_query_keyword() {
        let q = parse_tool_search_query("notebook jupyter");
        match q {
            ToolSearchQuery::Keyword(s) => {
                assert_eq!(s, "notebook jupyter");
            }
            _ => panic!("Expected Keyword query"),
        }
    }

    #[test]
    fn test_parse_query_required() {
        let q = parse_tool_search_query("+slack send");
        match q {
            ToolSearchQuery::KeywordWithRequired { required, optional } => {
                assert_eq!(required, vec!["slack"]);
                assert_eq!(optional, vec!["send"]);
            }
            _ => panic!("Expected KeywordWithRequired query"),
        }
    }

    #[test]
    fn test_search_tools_keyword() {
        let tool1 = make_tool("WebSearch", Some(true), None, None);
        let tool2 = make_tool("WebFetch", Some(true), None, None);
        let tool3 = make_tool("Read", None, None, None);
        let tools = vec![&tool1, &tool2, &tool3];
        let results = search_tools_with_keywords("search", &tools, 5);
        assert!(results.contains(&"WebSearch".to_string()));
    }

    #[test]
    fn test_search_tools_exact_match() {
        let tool1 = make_tool("WebSearch", Some(true), None, None);
        let tool2 = make_tool("WebFetch", Some(true), None, None);
        let tools = vec![&tool1, &tool2];
        let results = search_tools_with_keywords("websearch", &tools, 5);
        assert_eq!(results, vec!["WebSearch"]);
    }

    #[test]
    fn test_search_tools_required_term() {
        let tool1 = make_tool("WebSearch", Some(true), None, None);
        let tool2 = make_tool("WebFetch", Some(true), None, None);
        let tool3 = make_tool("Read", None, None, None);
        let tools = vec![&tool1, &tool2, &tool3];
        let results = search_tools_with_keywords("+fetch web", &tools, 5);
        assert_eq!(results, vec!["WebFetch"]);
    }
}