use crate::mcp::McpToolInfo;
use anyhow::Result;
use serde_json::Value;
use std::cmp::Ordering;
use std::sync::Arc;
use tracing::{debug, info};
/// How much information a tool-discovery result should carry when rendered.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DetailLevel {
    /// Tool name and provider only.
    NameOnly,
    /// Tool name, provider, and description.
    NameAndDescription,
    /// Everything, including the tool's input schema.
    Full,
}

impl DetailLevel {
    /// Returns the fixed kebab-case label for this level (e.g. for log fields).
    pub fn as_str(&self) -> &'static str {
        match *self {
            DetailLevel::Full => "full",
            DetailLevel::NameAndDescription => "name-and-description",
            DetailLevel::NameOnly => "name-only",
        }
    }
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct ToolDiscoveryResult {
pub name: String,
pub provider: String,
pub description: String,
pub relevance_score: f32,
pub input_schema: Option<Value>,
}
impl ToolDiscoveryResult {
pub fn to_json(&self, detail_level: DetailLevel) -> Value {
match detail_level {
DetailLevel::NameOnly => serde_json::json!({
"name": self.name,
"provider": self.provider,
}),
DetailLevel::NameAndDescription => serde_json::json!({
"name": self.name,
"provider": self.provider,
"description": self.description,
}),
DetailLevel::Full => serde_json::json!({
"name": self.name,
"provider": self.provider,
"description": self.description,
"input_schema": self.input_schema,
}),
}
}
}
/// Keyword search and listing over the tools exposed by an MCP tool executor.
pub struct ToolDiscovery {
    /// Source of the tool catalogue; every public method re-lists tools from it.
    mcp_client: Arc<dyn crate::mcp::McpToolExecutor>,
}

impl ToolDiscovery {
    /// Upper bound on the number of results returned by [`Self::search_tools`].
    const MAX_SEARCH_RESULTS: usize = 5;

    /// Creates a discovery helper backed by `mcp_client`.
    pub fn new(mcp_client: Arc<dyn crate::mcp::McpToolExecutor>) -> Self {
        Self { mcp_client }
    }

    /// Searches all MCP tools for `keyword` and returns the best matches.
    ///
    /// Each tool is scored by `calculate_relevance`, and tools scoring `0.0` are dropped.
    /// The rest are sorted by descending relevance, and equal scores keep the client's
    /// listing order. At most five results are returned. `input_schema` is populated
    /// only for [`DetailLevel::Full`]. An empty keyword matches every tool.
    ///
    /// # Errors
    /// Propagates any error from listing the MCP tools.
    pub async fn search_tools(
        &self,
        keyword: &str,
        detail_level: DetailLevel,
    ) -> Result<Vec<ToolDiscoveryResult>> {
        let tools = self.mcp_client.list_mcp_tools().await?;
        debug!(
            keyword = keyword,
            count = tools.len(),
            "Searching tools for keyword"
        );
        let mut results = Vec::with_capacity(tools.len() / 4);
        for tool in tools {
            let relevance_score = self.calculate_relevance(&tool, keyword);
            if relevance_score <= 0.0 {
                continue;
            }
            // `tool` is owned here, so its fields are moved rather than cloned.
            let input_schema = match detail_level {
                DetailLevel::Full => Some(tool.input_schema),
                DetailLevel::NameOnly | DetailLevel::NameAndDescription => None,
            };
            results.push(ToolDiscoveryResult {
                name: tool.name,
                provider: tool.provider,
                description: tool.description,
                relevance_score,
                input_schema,
            });
        }
        // Stable sort, so equally relevant tools keep their listing order.
        results.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(Ordering::Equal)
        });
        let total_results = results.len();
        let limit = Self::MAX_SEARCH_RESULTS;
        if total_results > limit {
            info!(
                keyword = keyword,
                matched = total_results,
                displayed = limit,
                overflow = total_results - limit,
                detail_level = detail_level.as_str(),
                "Tool search completed with overflow"
            );
            results.truncate(limit);
        } else {
            info!(
                keyword = keyword,
                matched = total_results,
                detail_level = detail_level.as_str(),
                "Tool search completed"
            );
        }
        Ok(results)
    }

    /// Looks up a single tool by name (ASCII case-insensitive) and returns it with its
    /// input schema and a relevance score of `1.0`.
    ///
    /// Returns `Ok(None)` when no tool has that name. If several match, the first one
    /// listed wins.
    ///
    /// # Errors
    /// Propagates any error from listing the MCP tools.
    pub async fn get_tool_detail(&self, tool_name: &str) -> Result<Option<ToolDiscoveryResult>> {
        let tools = self.mcp_client.list_mcp_tools().await?;
        Ok(tools
            .into_iter()
            .find(|tool| tool.name.eq_ignore_ascii_case(tool_name))
            .map(|tool| ToolDiscoveryResult {
                name: tool.name,
                provider: tool.provider,
                description: tool.description,
                relevance_score: 1.0,
                input_schema: Some(tool.input_schema),
            }))
    }

    /// Groups every tool by provider.
    ///
    /// Providers are sorted by name, and tools within each provider are sorted by name.
    /// Results carry no input schema and a relevance score of `1.0`.
    ///
    /// # Errors
    /// Propagates any error from listing the MCP tools.
    pub async fn list_tools_by_provider(&self) -> Result<Vec<(String, Vec<ToolDiscoveryResult>)>> {
        let tools = self.mcp_client.list_mcp_tools().await?;
        let mut by_provider: rustc_hash::FxHashMap<String, Vec<ToolDiscoveryResult>> =
            rustc_hash::FxHashMap::default();
        for tool in tools {
            let provider = tool.provider;
            let result = ToolDiscoveryResult {
                name: tool.name,
                provider: provider.clone(),
                description: tool.description,
                relevance_score: 1.0,
                input_schema: None,
            };
            by_provider.entry(provider).or_default().push(result);
        }
        let mut result: Vec<(String, Vec<ToolDiscoveryResult>)> = by_provider
            .into_iter()
            .map(|(provider, mut tools)| {
                tools.sort_by(|a, b| a.name.cmp(&b.name));
                (provider, tools)
            })
            .collect();
        result.sort_by(|a, b| a.0.cmp(&b.0));
        Ok(result)
    }

    /// Scores how well `tool` matches `keyword`. Higher is better, and `0.0` means
    /// no match.
    ///
    /// Tiers, first match wins:
    /// - exact name match (ASCII case-insensitive): 1.0
    /// - keyword is a substring of the name: 0.8
    /// - keyword is a substring of the description: 0.6
    /// - keyword is a subsequence of the name: 0.5 × fuzzy score
    /// - keyword is a subsequence of the description: 0.3 × fuzzy score
    ///
    /// Substring and subsequence checks compare Unicode-lowercased text.
    fn calculate_relevance(&self, tool: &McpToolInfo, keyword: &str) -> f32 {
        if tool.name.eq_ignore_ascii_case(keyword) {
            return 1.0;
        }
        let keyword_lower = keyword.to_lowercase();
        // Lowercase each field once and reuse it for the substring and fuzzy checks.
        let name_lower = tool.name.to_lowercase();
        if name_lower.contains(&keyword_lower) {
            return 0.8;
        }
        let description_lower = tool.description.to_lowercase();
        if description_lower.contains(&keyword_lower) {
            return 0.6;
        }
        let name_fuzzy = self.fuzzy_score(&name_lower, &keyword_lower);
        if name_fuzzy > 0.3 {
            return 0.5 * name_fuzzy;
        }
        let desc_fuzzy = self.fuzzy_score(&description_lower, &keyword_lower);
        if desc_fuzzy > 0.3 {
            return 0.3 * desc_fuzzy;
        }
        0.0
    }

    /// Checks whether every character of `needle` appears in `haystack` in order
    /// (a subsequence match, not necessarily contiguous).
    ///
    /// Returns `1.0` when `needle` is empty or is a subsequence of `haystack`, and
    /// `0.0` otherwise. This is safe for any UTF-8 input.
    fn fuzzy_score(&self, haystack: &str, needle: &str) -> f32 {
        if needle.is_empty() {
            return 1.0;
        }
        if haystack.is_empty() {
            return 0.0;
        }
        let mut rest = haystack;
        for needle_char in needle.chars() {
            match rest.find(needle_char) {
                // Skip the whole matched character. Advancing by one byte would land
                // inside a multi-byte UTF-8 character and make the next slice panic.
                Some(pos) => rest = &rest[pos + needle_char.len_utf8()..],
                None => return 0.0,
            }
        }
        // Every needle character matched. Reporting 1.0 directly avoids the old
        // bytes-vs-chars normalisation that under-scored non-ASCII needles.
        1.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolDiscovery` backed by a mock client that exposes no tools.
    fn discovery() -> ToolDiscovery {
        ToolDiscovery::new(Arc::new(MockMcpClient))
    }

    #[test]
    fn fuzzy_score_exact_match() {
        let score = discovery().fuzzy_score("read_file", "read_file");
        assert_eq!(score, 1.0);
    }

    #[test]
    fn fuzzy_score_partial_match() {
        let score = discovery().fuzzy_score("read_file_contents", "read");
        assert!(score > 0.5 && score <= 1.0);
    }

    #[test]
    fn fuzzy_score_no_match() {
        let score = discovery().fuzzy_score("read_file", "xyz");
        assert_eq!(score, 0.0);
    }

    /// Inert executor: lists no tools, has no tools, and returns `null` for any call.
    #[derive(Default)]
    struct MockMcpClient;

    #[async_trait::async_trait]
    impl crate::mcp::McpToolExecutor for MockMcpClient {
        async fn execute_mcp_tool(&self, _tool_name: &str, _args: &Value) -> Result<Value> {
            Ok(Value::Null)
        }

        async fn list_mcp_tools(&self) -> Result<Vec<McpToolInfo>> {
            Ok(Vec::new())
        }

        async fn has_mcp_tool(&self, _tool_name: &str) -> Result<bool> {
            Ok(false)
        }

        fn get_status(&self) -> crate::mcp::McpClientStatus {
            crate::mcp::McpClientStatus {
                enabled: true,
                provider_count: 0,
                active_connections: 0,
                configured_providers: Vec::new(),
            }
        }
    }
}