Skip to main content

brainwires_tools/
tool_search.rs

1//! Tool Search - Meta-tool for discovering available tools dynamically
2
3use regex::Regex;
4use serde::Deserialize;
5use serde_json::{Value, json};
6use std::collections::HashMap;
7
8use crate::ToolRegistry;
9use brainwires_core::{Tool, ToolContext, ToolInputSchema, ToolResult};
10
11/// Search mode for tool discovery
12#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Default)]
13#[serde(rename_all = "lowercase")]
14pub enum SearchMode {
15    /// Keyword-based search (default).
16    #[default]
17    Keyword,
18    /// Regex-based search.
19    Regex,
20    /// Semantic embedding-based search (requires `rag` feature).
21    Semantic,
22}
23
/// Meta-tool for discovering available tools dynamically.
///
/// This is a stateless unit type: all behavior lives in associated functions
/// (`get_tools` for the tool definition, `execute` to run a search). The derives
/// are free for a unit struct and let callers log, copy, or default-construct it.
#[derive(Debug, Clone, Copy, Default)]
pub struct ToolSearchTool;
26
27impl ToolSearchTool {
28    /// Return tool definitions for tool search.
29    pub fn get_tools() -> Vec<Tool> {
30        vec![Self::search_tools_tool()]
31    }
32
33    fn search_tools_tool() -> Tool {
34        let mut properties = HashMap::new();
35        properties.insert(
36            "query".to_string(),
37            json!({"type": "string", "description": "Search query to find relevant tools"}),
38        );
39        properties.insert("mode".to_string(), json!({"type": "string", "enum": ["keyword", "regex", "semantic"], "description": "Search mode: keyword (substring match), regex (pattern match), or semantic (embedding similarity, requires rag feature)", "default": "keyword"}));
40        properties.insert(
41            "include_deferred".to_string(),
42            json!({"type": "boolean", "description": "Include deferred tools", "default": true}),
43        );
44        properties.insert(
45            "limit".to_string(),
46            json!({"type": "integer", "description": "Maximum number of results to return (semantic mode only)", "default": 10}),
47        );
48        properties.insert(
49            "min_score".to_string(),
50            json!({"type": "number", "description": "Minimum similarity score 0.0-1.0 (semantic mode only)", "default": 0.3}),
51        );
52        Tool {
53            name: "search_tools".to_string(),
54            description: "Search for available tools by name or description.".to_string(),
55            input_schema: ToolInputSchema::object(properties, vec!["query".to_string()]),
56            requires_approval: false,
57            defer_loading: false,
58            ..Default::default()
59        }
60    }
61
62    /// Execute the tool search tool by name.
63    #[tracing::instrument(
64        name = "tool.execute",
65        skip(input, _context, registry),
66        fields(tool_name)
67    )]
68    pub fn execute(
69        tool_use_id: &str,
70        tool_name: &str,
71        input: &Value,
72        _context: &ToolContext,
73        registry: &ToolRegistry,
74    ) -> ToolResult {
75        let result = match tool_name {
76            "search_tools" => Self::search_tools(input, registry),
77            _ => Err(anyhow::anyhow!("Unknown tool search tool: {}", tool_name)),
78        };
79        match result {
80            Ok(output) => ToolResult::success(tool_use_id.to_string(), output),
81            Err(e) => ToolResult::error(
82                tool_use_id.to_string(),
83                format!("Tool search failed: {}", e),
84            ),
85        }
86    }
87
88    fn search_tools(input: &Value, registry: &ToolRegistry) -> anyhow::Result<String> {
89        #[derive(Deserialize)]
90        #[allow(dead_code)] // limit and min_score are used only with the `rag` feature
91        struct Input {
92            query: String,
93            #[serde(default)]
94            mode: SearchMode,
95            #[serde(default = "dt")]
96            include_deferred: bool,
97            #[serde(default = "default_limit")]
98            limit: usize,
99            #[serde(default = "default_min_score")]
100            min_score: f32,
101        }
102        fn dt() -> bool {
103            true
104        }
105        fn default_limit() -> usize {
106            10
107        }
108        fn default_min_score() -> f32 {
109            0.3
110        }
111
112        let params: Input = serde_json::from_value(input.clone())?;
113
114        // Handle semantic mode separately
115        #[cfg(feature = "rag")]
116        if params.mode == SearchMode::Semantic {
117            return Self::search_tools_semantic(
118                &params.query,
119                registry,
120                params.include_deferred,
121                params.limit,
122                params.min_score,
123            );
124        }
125
126        #[cfg(not(feature = "rag"))]
127        if params.mode == SearchMode::Semantic {
128            return Err(anyhow::anyhow!(
129                "Semantic search mode requires the 'rag' feature to be enabled. Use 'keyword' or 'regex' mode instead."
130            ));
131        }
132
133        if params.mode == SearchMode::Regex && params.query.len() > 200 {
134            return Err(anyhow::anyhow!(
135                "Regex pattern exceeds maximum length of 200 characters (got {})",
136                params.query.len()
137            ));
138        }
139
140        let regex =
141            if params.mode == SearchMode::Regex {
142                Some(Regex::new(&params.query).map_err(|e| {
143                    anyhow::anyhow!("Invalid regex pattern '{}': {}", params.query, e)
144                })?)
145            } else {
146                None
147            };
148
149        let query_lower = params.query.to_lowercase();
150        let query_terms: Vec<&str> = query_lower.split_whitespace().collect();
151
152        let matching_tools: Vec<&Tool> = registry
153            .get_all()
154            .iter()
155            .filter(|tool| {
156                if tool.defer_loading && !params.include_deferred {
157                    return false;
158                }
159                let search_text = format!("{} {}", tool.name, tool.description);
160                match &regex {
161                    Some(re) => re.is_match(&search_text),
162                    None => {
163                        let name_lower = tool.name.to_lowercase();
164                        let desc_lower = tool.description.to_lowercase();
165                        query_terms
166                            .iter()
167                            .any(|term| name_lower.contains(term) || desc_lower.contains(term))
168                    }
169                }
170            })
171            .collect();
172
173        if matching_tools.is_empty() {
174            return Ok(format!(
175                "No tools found matching query: \"{}\"",
176                params.query
177            ));
178        }
179
180        let mut result = format!(
181            "Found {} tools matching \"{}\":\n\n",
182            matching_tools.len(),
183            params.query
184        );
185        for tool in matching_tools {
186            Self::format_tool(&mut result, tool, None);
187        }
188        Ok(result)
189    }
190
191    /// Format a single tool entry for output.
192    fn format_tool(result: &mut String, tool: &Tool, score: Option<f32>) {
193        result.push_str(&format!("## {}\n", tool.name));
194        if let Some(s) = score {
195            result.push_str(&format!("**Similarity:** {:.2}\n", s));
196        }
197        result.push_str(&format!("**Description:** {}\n", tool.description));
198        if let Some(props) = &tool.input_schema.properties {
199            result.push_str("**Parameters:**\n");
200            for (name, schema) in props {
201                let desc = schema
202                    .get("description")
203                    .and_then(|v| v.as_str())
204                    .unwrap_or("No description");
205                let ptype = schema
206                    .get("type")
207                    .and_then(|v| v.as_str())
208                    .unwrap_or("unknown");
209                result.push_str(&format!("  - `{}` ({}): {}\n", name, ptype, desc));
210            }
211        }
212        result.push('\n');
213    }
214
215    /// Semantic search using embedding similarity.
216    #[cfg(feature = "rag")]
217    fn search_tools_semantic(
218        query: &str,
219        registry: &ToolRegistry,
220        include_deferred: bool,
221        limit: usize,
222        min_score: f32,
223    ) -> anyhow::Result<String> {
224        use crate::tool_embedding::ToolEmbeddingIndex;
225        use std::sync::OnceLock;
226
227        // Cache the embedding index; rebuild if tool count changes.
228        static CACHED_INDEX: OnceLock<(usize, ToolEmbeddingIndex)> = OnceLock::new();
229
230        let tools: Vec<&Tool> = registry
231            .get_all()
232            .iter()
233            .filter(|t| include_deferred || !t.defer_loading)
234            .collect();
235
236        // Build tool pairs for embedding
237        let tool_pairs: Vec<(String, String)> = tools
238            .iter()
239            .map(|t| (t.name.clone(), t.description.clone()))
240            .collect();
241
242        // Use cached index if tool count hasn't changed, otherwise build new one.
243        // OnceLock means first call builds, subsequent calls reuse.
244        // If tools change (e.g., MCP tools added), the count won't match and we
245        // fall through to building a fresh index.
246        let index = CACHED_INDEX.get_or_init(|| {
247            let idx = ToolEmbeddingIndex::build(&tool_pairs)
248                .expect("Failed to build tool embedding index");
249            (tool_pairs.len(), idx)
250        });
251
252        // If tool count changed, we need a fresh index but can't replace OnceLock.
253        // In that case, build an ad-hoc index.
254        let search_results = if index.0 != tool_pairs.len() {
255            let fresh_index = ToolEmbeddingIndex::build(&tool_pairs)?;
256            fresh_index.search(query, limit, min_score)?
257        } else {
258            index.1.search(query, limit, min_score)?
259        };
260
261        if search_results.is_empty() {
262            return Ok(format!(
263                "No tools found semantically matching query: \"{}\" (min_score: {:.2})",
264                query, min_score
265            ));
266        }
267
268        let mut result = format!(
269            "Found {} tools semantically matching \"{}\":\n\n",
270            search_results.len(),
271            query
272        );
273
274        for (tool_name, score) in &search_results {
275            if let Some(tool) = registry.get(tool_name) {
276                Self::format_tool(&mut result, tool, Some(*score));
277            }
278        }
279
280        Ok(result)
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_tools() {
        let tools = ToolSearchTool::get_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search_tools");
    }

    #[test]
    fn test_search_mode_default() {
        let mode = SearchMode::default();
        assert_eq!(mode, SearchMode::Keyword);
    }

    /// `rename_all = "lowercase"` accepts only the lowercase variant names.
    #[test]
    fn test_search_mode_deserializes_lowercase_names() {
        let parse = |s: &str| serde_json::from_value::<SearchMode>(json!(s));
        assert_eq!(parse("keyword").unwrap(), SearchMode::Keyword);
        assert_eq!(parse("regex").unwrap(), SearchMode::Regex);
        assert_eq!(parse("semantic").unwrap(), SearchMode::Semantic);
        assert!(parse("Regex").is_err());
        assert!(parse("fuzzy").is_err());
    }

    /// Every input the parser reads should be advertised in the JSON schema.
    #[test]
    fn test_schema_declares_all_inputs() {
        let tools = ToolSearchTool::get_tools();
        let props = tools[0]
            .input_schema
            .properties
            .as_ref()
            .expect("search_tools schema should declare properties");
        for key in ["query", "mode", "include_deferred", "limit", "min_score"] {
            assert!(props.contains_key(key), "missing schema property `{key}`");
        }
    }

    #[test]
    fn test_format_tool_with_score() {
        let tool = Tool {
            name: "demo".to_string(),
            description: "A demo tool".to_string(),
            ..Default::default()
        };
        let mut out = String::new();
        ToolSearchTool::format_tool(&mut out, &tool, Some(0.5));
        assert!(out.starts_with("## demo\n**Similarity:** 0.50\n**Description:** A demo tool\n"));
        assert!(out.ends_with("\n\n"));
    }

    #[test]
    fn test_format_tool_without_score_omits_similarity() {
        let tool = Tool {
            name: "demo".to_string(),
            description: "A demo tool".to_string(),
            ..Default::default()
        };
        let mut out = String::new();
        ToolSearchTool::format_tool(&mut out, &tool, None);
        assert!(out.starts_with("## demo\n**Description:** A demo tool\n"));
        assert!(!out.contains("**Similarity:**"));
    }
}