// claude_agent/tools/search/manager.rs
use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7
8use super::engine::{SearchEngine, SearchMode};
9use super::index::{ToolIndex, ToolIndexEntry};
10use crate::mcp::{McpManager, McpToolDefinition, McpToolsetRegistry};
11use crate::types::ToolDefinition;
12
13#[derive(Debug, Clone)]
14pub struct ToolSearchConfig {
15 pub threshold: f64,
16 pub context_window: usize,
17 pub search_mode: SearchMode,
18 pub max_results: usize,
19 pub always_load: Vec<String>,
20}
21
22impl Default for ToolSearchConfig {
23 fn default() -> Self {
24 Self {
25 threshold: 0.10,
26 context_window: 200_000,
27 search_mode: SearchMode::Regex,
28 max_results: 5,
29 always_load: Vec::new(),
30 }
31 }
32}
33
34impl ToolSearchConfig {
35 pub fn threshold_tokens(&self) -> usize {
36 (self.context_window as f64 * self.threshold) as usize
37 }
38
39 pub fn with_threshold(mut self, threshold: f64) -> Self {
40 self.threshold = threshold.clamp(0.0, 1.0);
41 self
42 }
43
44 pub fn with_context_window(mut self, tokens: usize) -> Self {
45 self.context_window = tokens;
46 self
47 }
48
49 pub fn with_search_mode(mut self, mode: SearchMode) -> Self {
50 self.search_mode = mode;
51 self
52 }
53
54 pub fn with_always_load(mut self, tools: Vec<String>) -> Self {
55 self.always_load = tools;
56 self
57 }
58}
59
60pub struct ToolSearchManager {
61 config: ToolSearchConfig,
62 index: Arc<RwLock<ToolIndex>>,
63 definitions: Arc<RwLock<HashMap<String, McpToolDefinition>>>,
64 engine: SearchEngine,
65 toolset_registry: Arc<RwLock<McpToolsetRegistry>>,
66}
67
68impl ToolSearchManager {
69 pub fn new(config: ToolSearchConfig) -> Self {
70 let engine = SearchEngine::new(config.search_mode);
71 Self {
72 config,
73 index: Arc::new(RwLock::new(ToolIndex::new())),
74 definitions: Arc::new(RwLock::new(HashMap::new())),
75 engine,
76 toolset_registry: Arc::new(RwLock::new(McpToolsetRegistry::new())),
77 }
78 }
79
80 pub fn config(&self) -> &ToolSearchConfig {
81 &self.config
82 }
83
84 pub fn set_toolset_registry(&self, registry: McpToolsetRegistry) -> &Self {
85 if let Ok(mut guard) = self.toolset_registry.try_write() {
86 *guard = registry;
87 }
88 self
89 }
90
91 pub async fn build_index(&self, mcp_manager: &McpManager) {
92 let tools = mcp_manager.list_tools().await;
93
94 let mut index = self.index.write().await;
95 let mut definitions = self.definitions.write().await;
96
97 index.clear();
98 definitions.clear();
99
100 for (qualified_name, tool) in tools {
101 if let Some((server, _)) = crate::mcp::parse_mcp_name(&qualified_name) {
102 let entry = ToolIndexEntry::from_mcp_tool(server, &tool);
103 index.add(entry);
104 definitions.insert(qualified_name, tool);
105 }
106 }
107 }
108
109 pub async fn should_use_search(&self) -> bool {
110 let index = self.index.read().await;
111 index.total_tokens() > self.config.threshold_tokens()
112 }
113
114 pub async fn total_tokens(&self) -> usize {
115 self.index.read().await.total_tokens()
116 }
117
118 pub async fn tool_count(&self) -> usize {
119 self.index.read().await.len()
120 }
121
122 pub async fn prepare_tools(&self) -> PreparedTools {
123 let index = self.index.read().await;
124 let definitions = self.definitions.read().await;
125 let toolset_registry = self.toolset_registry.read().await;
126
127 let use_search = index.total_tokens() > self.config.threshold_tokens();
128 let mut immediate = Vec::new();
129 let mut deferred = Vec::new();
130
131 for entry in index.entries() {
132 let Some(def) = definitions.get(&entry.qualified_name) else {
133 continue;
134 };
135
136 let is_always_load = self.config.always_load.contains(&entry.qualified_name)
137 || self.config.always_load.contains(&entry.tool_name);
138
139 if is_always_load {
141 let tool_def = ToolDefinition {
142 name: entry.qualified_name.clone(),
143 description: def.description.clone(),
144 input_schema: def.input_schema.clone(),
145 strict: None,
146 defer_loading: None,
147 };
148 immediate.push(tool_def);
149 continue;
150 }
151
152 let toolset_deferred =
154 toolset_registry.is_deferred(&entry.server_name, &entry.tool_name);
155
156 let should_defer = toolset_deferred || use_search;
158
159 let tool_def = ToolDefinition {
160 name: entry.qualified_name.clone(),
161 description: def.description.clone(),
162 input_schema: def.input_schema.clone(),
163 strict: None,
164 defer_loading: if should_defer { Some(true) } else { None },
165 };
166
167 if should_defer {
168 deferred.push(tool_def);
169 } else {
170 immediate.push(tool_def);
171 }
172 }
173
174 PreparedTools {
175 use_search,
176 search_mode: self.config.search_mode,
177 immediate,
178 deferred,
179 total_tokens: index.total_tokens(),
180 threshold_tokens: self.config.threshold_tokens(),
181 }
182 }
183
184 pub async fn search(&self, query: &str) -> Vec<String> {
185 let index = self.index.read().await;
186 let hits = self.engine.search(&index, query, self.config.max_results);
187 hits.into_iter().map(|h| h.entry.qualified_name).collect()
188 }
189
190 pub async fn get_definition(&self, qualified_name: &str) -> Option<ToolDefinition> {
191 let definitions = self.definitions.read().await;
192 definitions.get(qualified_name).map(|def| ToolDefinition {
193 name: qualified_name.to_string(),
194 description: def.description.clone(),
195 input_schema: def.input_schema.clone(),
196 strict: None,
197 defer_loading: None,
198 })
199 }
200
201 pub async fn get_definitions(&self, names: &[String]) -> Vec<ToolDefinition> {
202 let definitions = self.definitions.read().await;
203 names
204 .iter()
205 .filter_map(|name| {
206 definitions.get(name).map(|def| ToolDefinition {
207 name: name.clone(),
208 description: def.description.clone(),
209 input_schema: def.input_schema.clone(),
210 strict: None,
211 defer_loading: None,
212 })
213 })
214 .collect()
215 }
216}
217
218impl Default for ToolSearchManager {
219 fn default() -> Self {
220 Self::new(ToolSearchConfig::default())
221 }
222}
223
224#[derive(Debug)]
225pub struct PreparedTools {
226 pub use_search: bool,
227 pub search_mode: SearchMode,
228 pub immediate: Vec<ToolDefinition>,
229 pub deferred: Vec<ToolDefinition>,
230 pub total_tokens: usize,
231 pub threshold_tokens: usize,
232}
233
234impl PreparedTools {
235 pub fn all_tools(&self) -> impl Iterator<Item = &ToolDefinition> {
236 self.immediate.iter().chain(self.deferred.iter())
237 }
238
239 pub fn token_savings(&self) -> usize {
240 if self.use_search {
241 self.deferred
242 .iter()
243 .map(|t| t.estimated_tokens())
244 .sum::<usize>()
245 } else {
246 0
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_threshold_tokens() {
        let config = ToolSearchConfig::default();
        assert_eq!(config.threshold_tokens(), 20_000);
    }

    #[test]
    fn test_config_builder() {
        let config = ToolSearchConfig::default()
            .with_threshold(0.05)
            .with_context_window(100_000)
            .with_search_mode(SearchMode::Bm25);

        assert_eq!(config.threshold, 0.05);
        assert_eq!(config.context_window, 100_000);
        assert_eq!(config.search_mode, SearchMode::Bm25);
        assert_eq!(config.threshold_tokens(), 5_000);
    }

    #[test]
    fn test_config_threshold_is_clamped() {
        assert_eq!(ToolSearchConfig::default().with_threshold(1.5).threshold, 1.0);
        assert_eq!(ToolSearchConfig::default().with_threshold(-0.5).threshold, 0.0);
    }

    #[test]
    fn test_config_always_load_builder() {
        let config = ToolSearchConfig::default().with_always_load(vec!["read_file".to_string()]);
        assert_eq!(config.always_load, vec!["read_file".to_string()]);
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = ToolSearchManager::default();
        assert!(!manager.should_use_search().await);
        assert_eq!(manager.total_tokens().await, 0);
    }

    #[tokio::test]
    async fn test_prepare_tools_on_empty_manager() {
        let manager = ToolSearchManager::default();
        let prepared = manager.prepare_tools().await;

        assert!(!prepared.use_search);
        assert!(prepared.immediate.is_empty());
        assert!(prepared.deferred.is_empty());
        assert_eq!(prepared.total_tokens, 0);
        assert_eq!(prepared.threshold_tokens, 20_000);
        assert_eq!(prepared.token_savings(), 0);
        assert_eq!(prepared.all_tools().count(), 0);
    }

    #[tokio::test]
    async fn test_unknown_definitions_are_absent() {
        let manager = ToolSearchManager::default();
        assert!(manager.get_definition("mcp__none__tool").await.is_none());
        assert!(manager
            .get_definitions(&["mcp__none__tool".to_string()])
            .await
            .is_empty());
    }
}