claude_agent/tools/search/
manager.rs

1//! Tool search manager for coordinating search operations.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7
8use super::engine::{SearchEngine, SearchMode};
9use super::index::{ToolIndex, ToolIndexEntry};
10use crate::mcp::{McpManager, McpToolDefinition, McpToolsetRegistry};
11use crate::types::ToolDefinition;
12
13#[derive(Debug, Clone)]
14pub struct ToolSearchConfig {
15    pub threshold: f64,
16    pub context_window: usize,
17    pub search_mode: SearchMode,
18    pub max_results: usize,
19    pub always_load: Vec<String>,
20}
21
22impl Default for ToolSearchConfig {
23    fn default() -> Self {
24        Self {
25            threshold: 0.10,
26            context_window: 200_000,
27            search_mode: SearchMode::Regex,
28            max_results: 5,
29            always_load: Vec::new(),
30        }
31    }
32}
33
34impl ToolSearchConfig {
35    pub fn threshold_tokens(&self) -> usize {
36        (self.context_window as f64 * self.threshold) as usize
37    }
38
39    pub fn with_threshold(mut self, threshold: f64) -> Self {
40        self.threshold = threshold.clamp(0.0, 1.0);
41        self
42    }
43
44    pub fn with_context_window(mut self, tokens: usize) -> Self {
45        self.context_window = tokens;
46        self
47    }
48
49    pub fn with_search_mode(mut self, mode: SearchMode) -> Self {
50        self.search_mode = mode;
51        self
52    }
53
54    pub fn with_always_load(mut self, tools: Vec<String>) -> Self {
55        self.always_load = tools;
56        self
57    }
58}
59
/// Indexes MCP tools and decides which of them are sent immediately and which
/// are deferred behind tool search.
pub struct ToolSearchManager {
    /// Search settings, including the token threshold and `always_load` list.
    config: ToolSearchConfig,
    /// Searchable tool index, rebuilt by `build_index`.
    index: Arc<RwLock<ToolIndex>>,
    /// Full MCP tool definitions keyed by qualified tool name.
    definitions: Arc<RwLock<HashMap<String, McpToolDefinition>>>,
    /// Engine that runs queries against `index`; built from `config.search_mode`.
    engine: SearchEngine,
    /// Per-toolset deferral settings consulted by `prepare_tools`.
    toolset_registry: Arc<RwLock<McpToolsetRegistry>>,
}
67
68impl ToolSearchManager {
69    pub fn new(config: ToolSearchConfig) -> Self {
70        let engine = SearchEngine::new(config.search_mode);
71        Self {
72            config,
73            index: Arc::new(RwLock::new(ToolIndex::new())),
74            definitions: Arc::new(RwLock::new(HashMap::new())),
75            engine,
76            toolset_registry: Arc::new(RwLock::new(McpToolsetRegistry::new())),
77        }
78    }
79
80    pub fn config(&self) -> &ToolSearchConfig {
81        &self.config
82    }
83
84    pub fn set_toolset_registry(&self, registry: McpToolsetRegistry) -> &Self {
85        if let Ok(mut guard) = self.toolset_registry.try_write() {
86            *guard = registry;
87        }
88        self
89    }
90
91    pub async fn build_index(&self, mcp_manager: &McpManager) {
92        let tools = mcp_manager.list_tools().await;
93
94        let mut index = self.index.write().await;
95        let mut definitions = self.definitions.write().await;
96
97        index.clear();
98        definitions.clear();
99
100        for (qualified_name, tool) in tools {
101            if let Some((server, _)) = crate::mcp::parse_mcp_name(&qualified_name) {
102                let entry = ToolIndexEntry::from_mcp_tool(server, &tool);
103                index.add(entry);
104                definitions.insert(qualified_name, tool);
105            }
106        }
107    }
108
109    pub async fn should_use_search(&self) -> bool {
110        let index = self.index.read().await;
111        index.total_tokens() > self.config.threshold_tokens()
112    }
113
114    pub async fn total_tokens(&self) -> usize {
115        self.index.read().await.total_tokens()
116    }
117
118    pub async fn tool_count(&self) -> usize {
119        self.index.read().await.len()
120    }
121
122    pub async fn prepare_tools(&self) -> PreparedTools {
123        let index = self.index.read().await;
124        let definitions = self.definitions.read().await;
125        let toolset_registry = self.toolset_registry.read().await;
126
127        let use_search = index.total_tokens() > self.config.threshold_tokens();
128        let mut immediate = Vec::new();
129        let mut deferred = Vec::new();
130
131        for entry in index.entries() {
132            let Some(def) = definitions.get(&entry.qualified_name) else {
133                continue;
134            };
135
136            let is_always_load = self.config.always_load.contains(&entry.qualified_name)
137                || self.config.always_load.contains(&entry.tool_name);
138
139            // always_load has highest priority - never defer these tools
140            if is_always_load {
141                let tool_def = ToolDefinition {
142                    name: entry.qualified_name.clone(),
143                    description: def.description.clone(),
144                    input_schema: def.input_schema.clone(),
145                    strict: None,
146                    defer_loading: None,
147                };
148                immediate.push(tool_def);
149                continue;
150            }
151
152            // Toolset config takes precedence over automatic threshold
153            let toolset_deferred =
154                toolset_registry.is_deferred(&entry.server_name, &entry.tool_name);
155
156            // Defer if: toolset explicitly requests OR threshold exceeded
157            let should_defer = toolset_deferred || use_search;
158
159            let tool_def = ToolDefinition {
160                name: entry.qualified_name.clone(),
161                description: def.description.clone(),
162                input_schema: def.input_schema.clone(),
163                strict: None,
164                defer_loading: if should_defer { Some(true) } else { None },
165            };
166
167            if should_defer {
168                deferred.push(tool_def);
169            } else {
170                immediate.push(tool_def);
171            }
172        }
173
174        PreparedTools {
175            use_search,
176            search_mode: self.config.search_mode,
177            immediate,
178            deferred,
179            total_tokens: index.total_tokens(),
180            threshold_tokens: self.config.threshold_tokens(),
181        }
182    }
183
184    pub async fn search(&self, query: &str) -> Vec<String> {
185        let index = self.index.read().await;
186        let hits = self.engine.search(&index, query, self.config.max_results);
187        hits.into_iter().map(|h| h.entry.qualified_name).collect()
188    }
189
190    pub async fn get_definition(&self, qualified_name: &str) -> Option<ToolDefinition> {
191        let definitions = self.definitions.read().await;
192        definitions.get(qualified_name).map(|def| ToolDefinition {
193            name: qualified_name.to_string(),
194            description: def.description.clone(),
195            input_schema: def.input_schema.clone(),
196            strict: None,
197            defer_loading: None,
198        })
199    }
200
201    pub async fn get_definitions(&self, names: &[String]) -> Vec<ToolDefinition> {
202        let definitions = self.definitions.read().await;
203        names
204            .iter()
205            .filter_map(|name| {
206                definitions.get(name).map(|def| ToolDefinition {
207                    name: name.clone(),
208                    description: def.description.clone(),
209                    input_schema: def.input_schema.clone(),
210                    strict: None,
211                    defer_loading: None,
212                })
213            })
214            .collect()
215    }
216}
217
218impl Default for ToolSearchManager {
219    fn default() -> Self {
220        Self::new(ToolSearchConfig::default())
221    }
222}
223
/// Result of `ToolSearchManager::prepare_tools`: tool definitions split into
/// those sent up front and those marked for deferred loading.
#[derive(Debug)]
pub struct PreparedTools {
    /// Whether the index's token total exceeded `threshold_tokens`.
    pub use_search: bool,
    /// Search mode copied from the manager's configuration.
    pub search_mode: SearchMode,
    /// Tools that were not deferred (always includes `always_load` tools).
    pub immediate: Vec<ToolDefinition>,
    /// Tools marked with `defer_loading: Some(true)`.
    pub deferred: Vec<ToolDefinition>,
    /// Token total reported by the index when the tools were prepared.
    pub total_tokens: usize,
    /// Threshold computed from the manager's configuration.
    pub threshold_tokens: usize,
}
233
234impl PreparedTools {
235    pub fn all_tools(&self) -> impl Iterator<Item = &ToolDefinition> {
236        self.immediate.iter().chain(self.deferred.iter())
237    }
238
239    pub fn token_savings(&self) -> usize {
240        if self.use_search {
241            self.deferred
242                .iter()
243                .map(|t| t.estimated_tokens())
244                .sum::<usize>()
245        } else {
246            0
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// 10% of the default 200k-token window.
    #[test]
    fn test_config_threshold_tokens() {
        let expected = 20_000;
        assert_eq!(ToolSearchConfig::default().threshold_tokens(), expected);
    }

    /// Each builder call lands in its corresponding field.
    #[test]
    fn test_config_builder() {
        let built = ToolSearchConfig::default()
            .with_threshold(0.05)
            .with_context_window(100_000)
            .with_search_mode(SearchMode::Bm25);

        let ToolSearchConfig {
            threshold,
            context_window,
            search_mode,
            ..
        } = &built;

        assert_eq!(*threshold, 0.05);
        assert_eq!(*context_window, 100_000);
        assert_eq!(*search_mode, SearchMode::Bm25);
        assert_eq!(built.threshold_tokens(), 5_000);
    }

    /// A freshly created manager reports zero tokens and does not need search.
    #[tokio::test]
    async fn test_manager_creation() {
        let manager = ToolSearchManager::default();
        let needs_search = manager.should_use_search().await;
        let tokens = manager.total_tokens().await;
        assert!(!needs_search);
        assert_eq!(tokens, 0);
    }
}
280}