Skip to main content

mcp_proxy/
discovery.rs

//! BM25-based tool discovery and search.
//!
//! When `tool_discovery = true` is set in the proxy config, this module
//! indexes all tools from all backends using jpx-engine's BM25 search and
//! exposes discovery tools under the `proxy/` namespace:
//!
//! - `proxy/search_tools` -- Full-text search across tool names, descriptions, and tags
//! - `proxy/similar_tools` -- Find tools related to a given tool
//! - `proxy/tool_categories` -- Browse tools by backend category
10
11use std::sync::Arc;
12
13use jpx_engine::{
14    CategorySummary, DiscoveryRegistry, DiscoverySpec, ParamSpec, ServerInfo, ToolQueryResult,
15    ToolSpec,
16};
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use tokio::sync::RwLock;
20use tower_mcp::proxy::McpProxy;
21use tower_mcp::{CallToolResult, NoParams, ToolBuilder, ToolDefinition};
22
/// Shared discovery index, wrapped for concurrent access from tool handlers.
///
/// The registry sits behind a Tokio `RwLock` inside an `Arc`: the discovery
/// tool handlers in this module take read locks to run queries, while
/// `reindex` takes the write lock to replace the whole registry.
pub type SharedDiscoveryIndex = Arc<RwLock<DiscoveryRegistry>>;
25
26/// Build a discovery index from the proxy's current tool list.
27///
28/// Sends a `ListTools` request through the proxy to collect all registered
29/// tools, then indexes them using jpx-engine's BM25 search.
30pub async fn build_index(proxy: &mut McpProxy, separator: &str) -> SharedDiscoveryIndex {
31    use tower::Service;
32    use tower_mcp::protocol::{ListToolsParams, McpRequest, McpResponse, RequestId};
33    use tower_mcp::router::{Extensions, RouterRequest};
34
35    let req = RouterRequest {
36        id: RequestId::Number(0),
37        inner: McpRequest::ListTools(ListToolsParams::default()),
38        extensions: Extensions::new(),
39    };
40
41    let tools = match proxy.call(req).await {
42        Ok(resp) => match resp.inner {
43            Ok(McpResponse::ListTools(result)) => result.tools,
44            _ => {
45                tracing::warn!("Failed to list tools for discovery indexing");
46                vec![]
47            }
48        },
49        Err(_) => vec![],
50    };
51
52    let mut registry = DiscoveryRegistry::new();
53    index_tools(&mut registry, &tools, separator);
54
55    tracing::info!(tools_indexed = tools.len(), "Built tool discovery index");
56
57    Arc::new(RwLock::new(registry))
58}
59
60/// Re-index all tools into an existing shared discovery index.
61///
62/// Called after hot reload adds, removes, or replaces backends to keep
63/// the search index in sync with the proxy's current tool set.
64pub async fn reindex(index: &SharedDiscoveryIndex, proxy: &mut McpProxy, separator: &str) {
65    use tower::Service;
66    use tower_mcp::protocol::{ListToolsParams, McpRequest, McpResponse, RequestId};
67    use tower_mcp::router::{Extensions, RouterRequest};
68
69    let req = RouterRequest {
70        id: RequestId::Number(0),
71        inner: McpRequest::ListTools(ListToolsParams::default()),
72        extensions: Extensions::new(),
73    };
74
75    let tools = match proxy.call(req).await {
76        Ok(resp) => match resp.inner {
77            Ok(McpResponse::ListTools(result)) => result.tools,
78            _ => vec![],
79        },
80        Err(_) => vec![],
81    };
82
83    let mut registry = DiscoveryRegistry::new();
84    index_tools(&mut registry, &tools, separator);
85
86    let mut guard = index.write().await;
87    *guard = registry;
88
89    tracing::info!(tools_indexed = tools.len(), "Re-indexed tool discovery");
90}
91
92/// Index MCP tool definitions into the discovery registry.
93///
94/// Groups tools by backend namespace (derived from the separator) and registers
95/// each group as a discovery "server" with its tools.
96fn index_tools(registry: &mut DiscoveryRegistry, tools: &[ToolDefinition], separator: &str) {
97    // Group tools by backend namespace
98    let mut by_namespace: std::collections::HashMap<String, Vec<&ToolDefinition>> =
99        std::collections::HashMap::new();
100
101    for tool in tools {
102        let namespace = tool
103            .name
104            .split_once(separator)
105            .map(|(ns, _)| ns.to_string())
106            .unwrap_or_else(|| "default".to_string());
107        by_namespace.entry(namespace).or_default().push(tool);
108    }
109
110    for (namespace, ns_tools) in &by_namespace {
111        let tool_specs: Vec<ToolSpec> = ns_tools
112            .iter()
113            .map(|t| tool_definition_to_spec(t, separator))
114            .collect();
115
116        let spec = DiscoverySpec {
117            schema: None,
118            server: ServerInfo {
119                name: namespace.clone(),
120                version: None,
121                description: None,
122            },
123            tools: tool_specs,
124            categories: std::collections::HashMap::new(),
125        };
126
127        registry.register(spec, true);
128    }
129}
130
131/// Convert an MCP ToolDefinition to a jpx ToolSpec for indexing.
132fn tool_definition_to_spec(tool: &ToolDefinition, separator: &str) -> ToolSpec {
133    // Extract the local tool name (without namespace prefix)
134    let local_name = tool
135        .name
136        .split_once(separator)
137        .map(|(_, name)| name.to_string())
138        .unwrap_or_else(|| tool.name.clone());
139
140    // Extract parameter names from input schema
141    let params = extract_params(&tool.input_schema);
142
143    // Extract tags from annotations if available
144    let mut tags = Vec::new();
145    if let Some(annotations) = &tool.annotations {
146        if annotations.destructive_hint {
147            tags.push("destructive".to_string());
148        }
149        if annotations.read_only_hint {
150            tags.push("read-only".to_string());
151        }
152        if annotations.idempotent_hint {
153            tags.push("idempotent".to_string());
154        }
155        if annotations.open_world_hint {
156            tags.push("open-world".to_string());
157        }
158    }
159
160    // Extract category from namespace
161    let category = tool
162        .name
163        .split_once(separator)
164        .map(|(ns, _)| ns.to_string());
165
166    ToolSpec {
167        name: local_name,
168        aliases: vec![],
169        category,
170        subcategory: None,
171        tags,
172        summary: tool.description.clone(),
173        description: tool.description.clone(),
174        params,
175        returns: None,
176        examples: vec![],
177        related: vec![],
178        since: None,
179        stability: None,
180    }
181}
182
183/// Extract parameter specs from a JSON Schema input_schema.
184fn extract_params(schema: &serde_json::Value) -> Vec<ParamSpec> {
185    let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) else {
186        return vec![];
187    };
188    let required: std::collections::HashSet<&str> = schema
189        .get("required")
190        .and_then(|r| r.as_array())
191        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
192        .unwrap_or_default();
193
194    properties
195        .iter()
196        .map(|(name, prop)| ParamSpec {
197            name: name.clone(),
198            param_type: prop.get("type").and_then(|t| t.as_str()).map(String::from),
199            required: required.contains(name.as_str()),
200            description: prop
201                .get("description")
202                .and_then(|d| d.as_str())
203                .map(String::from),
204            enum_values: None,
205            default: None,
206        })
207        .collect()
208}
209
210// ---------------------------------------------------------------------------
211// Discovery tool handlers
212// ---------------------------------------------------------------------------
213
// Arguments accepted by the `search_tools` discovery tool handler.
//
// NOTE(review): `JsonSchema` is derived here, and schemars typically copies
// `///` doc comments into the generated schema as descriptions, so the field
// docs below are presumably client-visible -- confirm before rewording them.
#[derive(Debug, Deserialize, JsonSchema)]
struct SearchInput {
    /// Search query (e.g. "read file", "database query", "math operations")
    query: String,
    /// Maximum number of results to return (default: 10)
    // When the caller omits this field, serde fills it from `default_top_k`.
    #[serde(default = "default_top_k")]
    top_k: usize,
}
222
/// Serde default for `SearchInput::top_k`, used when the caller omits it.
fn default_top_k() -> usize {
    const DEFAULT_SEARCH_RESULTS: usize = 10;
    DEFAULT_SEARCH_RESULTS
}
226
// Arguments accepted by the `similar_tools` discovery tool handler.
//
// NOTE(review): `JsonSchema` is derived here, and schemars typically copies
// `///` doc comments into the generated schema as descriptions, so the field
// docs below are presumably client-visible -- confirm before rewording them.
#[derive(Debug, Deserialize, JsonSchema)]
struct SimilarInput {
    /// Tool ID to find similar tools for (e.g. "math:add")
    tool_id: String,
    /// Maximum number of results to return (default: 5)
    // When the caller omits this field, serde fills it from `default_similar_k`.
    #[serde(default = "default_similar_k")]
    top_k: usize,
}
235
/// Serde default for `SimilarInput::top_k`, used when the caller omits it.
fn default_similar_k() -> usize {
    const DEFAULT_SIMILAR_RESULTS: usize = 5;
    DEFAULT_SIMILAR_RESULTS
}
239
/// One entry of the JSON array returned by the `search_tools` and
/// `similar_tools` handlers.
///
/// Built from a jpx-engine `ToolQueryResult` through the `From` impl in this
/// file; every field is copied straight from that result.
#[derive(Serialize)]
struct SearchResultEntry {
    /// Copied from `ToolQueryResult::id`.
    id: String,
    /// Copied from `ToolQueryResult::server`.
    server: String,
    /// Name of the matched tool spec (`tool.name`).
    name: String,
    /// Description of the matched tool spec, if it has one.
    description: Option<String>,
    /// Copied from `ToolQueryResult::score` (presumably the BM25 relevance
    /// score -- confirm against jpx-engine).
    score: f64,
    /// Tags of the matched tool spec.
    tags: Vec<String>,
    /// Category of the matched tool spec, if it has one.
    category: Option<String>,
}
250
251impl From<ToolQueryResult> for SearchResultEntry {
252    fn from(r: ToolQueryResult) -> Self {
253        Self {
254            id: r.id,
255            server: r.server,
256            name: r.tool.name,
257            description: r.tool.description,
258            score: r.score,
259            tags: r.tool.tags,
260            category: r.tool.category,
261        }
262    }
263}
264
/// JSON payload returned by the `tool_categories` handler.
#[derive(Serialize)]
struct CategoriesResult {
    /// Category summaries; the handler sorts them by descending `tool_count`.
    categories: Vec<CategorySummary>,
    /// Number of entries in `categories`.
    total_categories: usize,
}
270
271/// Build the discovery tools and return them for inclusion in the admin router.
272pub fn build_discovery_tools(index: SharedDiscoveryIndex) -> Vec<tower_mcp::Tool> {
273    let index_for_search = Arc::clone(&index);
274    let search_tools = ToolBuilder::new("search_tools")
275        .description(
276            "Search for tools across all backends using BM25 full-text search. \
277             Searches tool names, descriptions, parameters, and tags.",
278        )
279        .handler(move |input: SearchInput| {
280            let idx = Arc::clone(&index_for_search);
281            async move {
282                let registry = idx.read().await;
283                let results = registry.query(&input.query, input.top_k);
284                let entries: Vec<SearchResultEntry> =
285                    results.into_iter().map(SearchResultEntry::from).collect();
286                Ok(CallToolResult::text(
287                    serde_json::to_string_pretty(&entries).unwrap(),
288                ))
289            }
290        })
291        .build();
292
293    let index_for_similar = Arc::clone(&index);
294    let similar_tools = ToolBuilder::new("similar_tools")
295        .description(
296            "Find tools similar to a given tool. Uses BM25 similarity based on \
297             shared terms in descriptions, parameters, and tags.",
298        )
299        .handler(move |input: SimilarInput| {
300            let idx = Arc::clone(&index_for_similar);
301            async move {
302                let registry = idx.read().await;
303                let results = registry.similar(&input.tool_id, input.top_k);
304                let entries: Vec<SearchResultEntry> =
305                    results.into_iter().map(SearchResultEntry::from).collect();
306                Ok(CallToolResult::text(
307                    serde_json::to_string_pretty(&entries).unwrap(),
308                ))
309            }
310        })
311        .build();
312
313    let index_for_categories = Arc::clone(&index);
314    let tool_categories = ToolBuilder::new("tool_categories")
315        .description(
316            "List all tool categories (backend namespaces) with tool counts. \
317             Useful for browsing available capabilities by domain.",
318        )
319        .handler(move |_: NoParams| {
320            let idx = Arc::clone(&index_for_categories);
321            async move {
322                let registry = idx.read().await;
323                let categories = registry.list_categories();
324                let mut cats: Vec<CategorySummary> = categories.into_values().collect();
325                cats.sort_by(|a, b| b.tool_count.cmp(&a.tool_count));
326                let result = CategoriesResult {
327                    total_categories: cats.len(),
328                    categories: cats,
329                };
330                Ok(CallToolResult::text(
331                    serde_json::to_string_pretty(&result).unwrap(),
332                ))
333            }
334        })
335        .build();
336
337    vec![search_tools, similar_tools, tool_categories]
338}