use anyhow::Context;
use futures::future::join_all;
use rmcp::model::Tool;
use rmcp::ServiceExt;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use tokio::process::Command;
use tokio::time::timeout;
pub mod error;
pub mod search;
pub use error::ToolSearchError;
pub use search::{load_servers, simple_search, SearchBuilder};
/// One MCP server to query, identified by `name` and reached via `transport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
/// Label used in results, sorting and error messages; `validate` rejects an empty name.
pub name: String,
/// How to reach the server (spawned stdio process or SSE URL).
pub transport: TransportConfig,
}
impl ServerConfig {
    /// Checks that this configuration is usable before any connection is attempted.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the server name is empty, when a
    /// stdio transport has an empty command, or when an SSE transport has an
    /// empty URL or one that does not start with `http://` or `https://`.
    pub fn validate(&self) -> Result<(), String> {
        if self.name.is_empty() {
            return Err("Server name cannot be empty".to_string());
        }
        let server = &self.name;
        match &self.transport {
            TransportConfig::Stdio { command, .. } if command.is_empty() => {
                Err(format!("Command cannot be empty for server: {}", server))
            }
            TransportConfig::Stdio { .. } => Ok(()),
            TransportConfig::Sse { url, .. } if url.is_empty() => {
                Err(format!("URL cannot be empty for server: {}", server))
            }
            TransportConfig::Sse { url, .. } => {
                let has_http_scheme = ["http://", "https://"]
                    .iter()
                    .any(|&scheme| url.starts_with(scheme));
                if has_http_scheme {
                    Ok(())
                } else {
                    Err(format!("Invalid URL format for server {}: {}", server, url))
                }
            }
        }
    }
}
/// How to reach an MCP server.
///
/// Serialized as an internally tagged enum: the `"type"` key selects `"stdio"`
/// or `"sse"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum TransportConfig {
    /// Spawn a local process and speak MCP over its stdin/stdout.
    #[serde(rename = "stdio")]
    Stdio {
        /// Executable to launch.
        command: String,
        /// Command-line arguments. May be omitted in configuration files
        /// (defaults to no arguments), consistent with `env`.
        #[serde(default)]
        args: Vec<String>,
        /// Extra environment variables for the child process (defaults to none).
        #[serde(default)]
        env: HashMap<String, String>,
    },
    /// Connect over HTTP server-sent events. The connection code in this
    /// module currently rejects this transport as unsupported.
    #[serde(rename = "sse")]
    Sse {
        /// Endpoint URL; `validate` requires an `http://` or `https://` scheme.
        url: String,
        /// Extra HTTP headers to send (defaults to none).
        #[serde(default)]
        headers: HashMap<String, String>,
    },
}
/// A tool that satisfied the search criteria, together with the server that advertised it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSearchMatch {
/// `ServerConfig::name` of the server the tool came from.
pub server_name: String,
/// The tool definition as returned by the server's `list_tools`.
pub tool: Tool,
}
impl ToolSearchMatch {
    /// Borrows the matched tool's name, e.g. for sorting or display.
    pub fn tool_name(&self) -> &str {
        let name: &str = &self.tool.name;
        name
    }
}
/// Ordering applied to search results before `max_results` truncation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortOrder {
/// Sort by server name, then by tool name.
ServerThenTool,
/// Sort by tool name, then by server name.
ToolThenServer,
/// Keep results in the order servers and tools were processed.
None,
}
/// Execution options for `search_tools_with_options`.
#[derive(Debug, Clone)]
pub struct SearchOptions {
/// Limit applied separately to connecting and to each `list_tools` page
/// request; `None` waits indefinitely.
pub timeout: Option<Duration>,
/// How to order the matches.
pub sort_order: SortOrder,
/// When `true`, invalid configs and failing servers are reported on stderr
/// and skipped; when `false`, the first such error aborts the search.
pub continue_on_error: bool,
/// Keep at most this many matches (applied after sorting).
pub max_results: Option<usize>,
}
/// How the query/keywords of a `SearchCriteria` are compared with tool text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
/// The query must occur as a substring of a searched field.
Substring,
/// The query is interpreted as a regular expression.
Regex,
/// Every entry of `keywords` must occur as a substring of the same field.
Keywords,
/// The query must occur as a whole word (delimited by regex `\b` boundaries).
WordBoundary,
}
/// Selects which parts of a tool a text search inspects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SearchFields {
    /// Search the tool's name.
    pub name: bool,
    /// Search the tool's optional title.
    pub title: bool,
    /// Search the tool's optional description.
    pub description: bool,
    /// Search text extracted from the tool's JSON input schema.
    pub input_schema: bool,
}

impl Default for SearchFields {
    /// Name, title and description are searched; the input schema is opt-in.
    fn default() -> Self {
        let human_readable = true;
        SearchFields {
            name: human_readable,
            title: human_readable,
            description: human_readable,
            input_schema: false,
        }
    }
}
/// Filter applied to every tool returned by the configured servers.
///
/// Build one with the `with_*` constructors or `match_all`, then refine it with
/// the builder methods (`with_mode`, `with_fields`, `case_sensitive`).
#[derive(Debug, Clone)]
pub struct SearchCriteria {
/// Text searched for in `Substring`, `Regex` and `WordBoundary` modes.
pub query: Option<String>,
/// When set, only a tool with this exact name matches and all other criteria are ignored.
pub name: Option<String>,
/// How `query`/`keywords` are compared with the searched fields.
pub mode: SearchMode,
/// Which parts of a tool are searched.
pub fields: SearchFields,
/// Whether comparisons are case-sensitive; every constructor sets this to `false`.
pub case_sensitive: bool,
/// Tools whose description is missing or shorter than this are rejected.
pub min_description_length: Option<usize>,
/// Terms that must all occur within a single searched field in `Keywords` mode.
pub keywords: Vec<String>,
// Cached compilation of `query`, so the pattern is not rebuilt for every field of
// every tool. The `Result` keeps a compile error around so an invalid pattern
// matches nothing instead of panicking; the nested type is intentional, hence the allow.
#[allow(clippy::type_complexity)]
regex: Option<Result<Regex, regex::Error>>,
}
impl SearchCriteria {
pub fn with_query(query: String) -> Self {
Self {
query: Some(query),
name: None,
mode: SearchMode::Substring,
fields: SearchFields::default(),
case_sensitive: false,
min_description_length: None,
keywords: vec![],
regex: None,
}
}
pub fn with_name(name: String) -> Self {
Self {
query: None,
name: Some(name),
mode: SearchMode::Substring,
fields: SearchFields::default(),
case_sensitive: false,
min_description_length: None,
keywords: vec![],
regex: None,
}
}
pub fn with_regex(pattern: String) -> Self {
let regex = Regex::new(&pattern);
Self {
query: Some(pattern),
name: None,
mode: SearchMode::Regex,
fields: SearchFields::default(),
case_sensitive: false,
min_description_length: None,
keywords: vec![],
regex: Some(regex),
}
}
pub fn with_keywords(keywords: Vec<String>) -> Self {
Self {
query: None,
name: None,
mode: SearchMode::Keywords,
fields: SearchFields::default(),
case_sensitive: false,
min_description_length: None,
keywords,
regex: None,
}
}
pub fn match_all() -> Self {
Self {
query: None,
name: None,
mode: SearchMode::Substring,
fields: SearchFields::default(),
case_sensitive: false,
min_description_length: None,
keywords: vec![],
regex: None,
}
}
pub fn with_mode(mut self, mode: SearchMode) -> Self {
self.mode = mode;
if mode == SearchMode::Regex {
if let Some(ref query) = self.query {
self.regex = Some(Regex::new(query));
}
}
self
}
pub fn with_fields(mut self, fields: SearchFields) -> Self {
self.fields = fields;
self
}
pub fn case_sensitive(mut self, sensitive: bool) -> Self {
self.case_sensitive = sensitive;
self
}
fn extract_schema_text(schema: &Value) -> String {
let mut text = String::new();
if let Some(obj) = schema.as_object() {
if let Some(properties) = obj.get("properties") {
if let Some(props_obj) = properties.as_object() {
for key in props_obj.keys() {
text.push_str(key);
text.push(' ');
}
}
}
if let Some(desc) = obj.get("description").and_then(|v| v.as_str()) {
text.push_str(desc);
text.push(' ');
}
for value in obj.values() {
if value.is_object() {
text.push_str(&Self::extract_schema_text(value));
} else if let Some(s) = value.as_str() {
text.push_str(s);
text.push(' ');
}
}
}
text
}
fn text_matches(&self, text: &str) -> bool {
let search_text = if self.case_sensitive {
text.to_string()
} else {
text.to_lowercase()
};
match self.mode {
SearchMode::Substring => {
let query = if self.case_sensitive {
self.query.as_ref().unwrap().clone()
} else {
self.query.as_ref().unwrap().to_lowercase()
};
search_text.contains(&query)
}
SearchMode::Regex => {
if let Some(ref regex_result) = self.regex {
match regex_result {
Ok(regex) => regex.is_match(text),
Err(_) => false,
}
} else if let Some(ref query) = self.query {
match Regex::new(query) {
Ok(regex) => regex.is_match(text),
Err(_) => false,
}
} else {
false
}
}
SearchMode::Keywords => {
let keywords = if self.case_sensitive {
self.keywords.clone()
} else {
self.keywords.iter().map(|k| k.to_lowercase()).collect()
};
keywords.iter().all(|keyword| search_text.contains(keyword))
}
SearchMode::WordBoundary => {
let query = if self.case_sensitive {
self.query.as_ref().unwrap().clone()
} else {
self.query.as_ref().unwrap().to_lowercase()
};
let pattern = format!(r"\b{}\b", regex::escape(&query));
match Regex::new(&pattern) {
Ok(regex) => {
if self.case_sensitive {
regex.is_match(text)
} else {
regex.is_match(&search_text)
}
}
Err(_) => search_text.contains(&query),
}
}
}
}
pub fn matches(&self, tool: &Tool) -> bool {
if let Some(ref name) = self.name {
let tool_name: &str = tool.name.as_ref();
return if self.case_sensitive {
tool_name == name
} else {
tool_name.eq_ignore_ascii_case(name)
};
}
if let Some(min_len) = self.min_description_length {
if tool
.description
.as_ref()
.map(|d| d.len() < min_len)
.unwrap_or(true)
{
return false;
}
}
if self.query.is_none() && self.keywords.is_empty() {
return true;
}
let mut searchable_texts = Vec::new();
if self.fields.name {
searchable_texts.push(("name", tool.name.as_ref().to_string()));
}
if self.fields.title {
if let Some(ref title) = tool.title {
searchable_texts.push(("title", title.to_string()));
}
}
if self.fields.description {
if let Some(ref desc) = tool.description {
searchable_texts.push(("description", desc.as_ref().to_string()));
}
}
if self.fields.input_schema {
let schema_value: Value = serde_json::to_value(&*tool.input_schema)
.unwrap_or(Value::Object(serde_json::Map::new()));
let schema_text = Self::extract_schema_text(&schema_value);
if !schema_text.is_empty() {
searchable_texts.push(("input_schema", schema_text));
}
}
for (_field_name, text) in searchable_texts {
if self.text_matches(&text) {
return true;
}
}
false
}
}
/// Starts an MCP client session for `config`.
///
/// For stdio transports this spawns `command` with `args` and `env`, and uses
/// the child's stdout/stdin as the MCP transport.
///
/// # Errors
///
/// - Spawn failures and missing stdio handles are returned as errors.
/// - A failed MCP initialization yields `ToolSearchError::Connection`.
/// - SSE transports yield `ToolSearchError::UnsupportedTransport`.
async fn connect_to_server(
    config: &ServerConfig,
) -> Result<rmcp::service::RunningService<rmcp::RoleClient, ()>, ToolSearchError> {
    match &config.transport {
        TransportConfig::Stdio { command, args, env } => {
            let mut cmd = Command::new(command);
            cmd.args(args)
                .envs(env)
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                // Nothing reads the child's stderr. The `Child` handle, which
                // owns the read end of a stderr pipe, is dropped when this
                // function returns. A piped stderr would therefore fill up and
                // then be closed under the running server, so its log writes
                // would fail with EPIPE/SIGPIPE. Discard stderr instead.
                .stderr(Stdio::null());
            let mut child = cmd
                .spawn()
                .with_context(|| format!("Failed to spawn command: {}", command))?;
            let stdin = child.stdin.take().ok_or_else(|| {
                ToolSearchError::Connection("Failed to get stdin from child process".to_string())
            })?;
            let stdout = child.stdout.take().ok_or_else(|| {
                ToolSearchError::Connection("Failed to get stdout from child process".to_string())
            })?;
            let service = ().serve((stdout, stdin)).await.map_err(|e| {
                ToolSearchError::Connection(format!("Failed to initialize client: {}", e))
            })?;
            Ok(service)
        }
        TransportConfig::Sse { url, .. } => Err(ToolSearchError::UnsupportedTransport(format!(
            "SSE transport not yet implemented for URL: {}",
            url
        ))),
    }
}
pub async fn list_tools_from_server(
config: &ServerConfig,
) -> Result<Vec<Tool>, ToolSearchError> {
list_tools_from_server_with_timeout(config, None).await
}
/// Connects to `config` and collects every tool it advertises, following
/// pagination cursors until the server stops returning one.
///
/// `timeout_duration`, when set, bounds the connection handshake and each
/// `list_tools` page request individually, not the whole call.
///
/// # Errors
///
/// - On a timeout, returns `ToolSearchError::Connection`.
/// - If the server hands back a pagination cursor it already returned,
///   returns `ToolSearchError::Connection`. Without this check such a server
///   would keep the loop running forever, since timeouts only apply per page.
/// - Connection and request failures are propagated.
pub async fn list_tools_from_server_with_timeout(
    config: &ServerConfig,
    timeout_duration: Option<Duration>,
) -> Result<Vec<Tool>, ToolSearchError> {
    let connect_future = connect_to_server(config);
    let service = match timeout_duration {
        Some(limit) => timeout(limit, connect_future).await.map_err(|_| {
            ToolSearchError::Connection(format!(
                "Connection timeout after {:?} for server: {}",
                limit, config.name
            ))
        })??,
        None => connect_future.await?,
    };
    let peer = service.peer();
    let mut tools = Vec::new();
    let mut cursor = None;
    let mut seen_cursors = std::collections::HashSet::new();
    loop {
        let list_future = peer.list_tools(Some(rmcp::model::PaginatedRequestParam { cursor }));
        let page = match timeout_duration {
            Some(limit) => timeout(limit, list_future).await.map_err(|_| {
                ToolSearchError::Connection(format!(
                    "List tools timeout after {:?} for server: {}",
                    limit, config.name
                ))
            })??,
            None => list_future.await?,
        };
        tools.extend(page.tools);
        match page.next_cursor {
            Some(next) => {
                if !seen_cursors.insert(next.clone()) {
                    return Err(ToolSearchError::Connection(format!(
                        "Server {} returned a repeated pagination cursor; aborting tool listing",
                        config.name
                    )));
                }
                cursor = Some(next);
            }
            None => break,
        }
    }
    Ok(tools)
}
impl Default for SearchOptions {
fn default() -> Self {
Self {
timeout: Some(Duration::from_secs(30)),
sort_order: SortOrder::ServerThenTool,
continue_on_error: true,
max_results: None,
}
}
}
/// Searches all `servers` for tools matching `criteria`, using `SearchOptions::default()`.
pub async fn search_tools(
    servers: &[ServerConfig],
    criteria: &SearchCriteria,
) -> Result<Vec<ToolSearchMatch>, ToolSearchError> {
    let defaults = SearchOptions::default();
    search_tools_with_options(servers, criteria, &defaults).await
}
pub async fn search_tools_with_options(
servers: &[ServerConfig],
criteria: &SearchCriteria,
options: &SearchOptions,
) -> Result<Vec<ToolSearchMatch>, ToolSearchError> {
for server in servers {
if let Err(e) = server.validate() {
if !options.continue_on_error {
return Err(ToolSearchError::Connection(e));
}
eprintln!("Warning: Invalid server configuration {}: {}", server.name, e);
}
}
let server_futures: Vec<_> = servers
.iter()
.filter_map(|server_config| {
if server_config.validate().is_err() && options.continue_on_error {
return None;
}
let config = server_config.clone();
let timeout_dur = options.timeout;
Some(async move {
let result = list_tools_from_server_with_timeout(&config, timeout_dur).await;
(config.name.clone(), result)
})
})
.collect();
let server_results = join_all(server_futures).await;
let mut results = Vec::new();
let mut errors = Vec::new();
for (server_name, server_result) in server_results {
match server_result {
Ok(tools) => {
for tool in tools {
if criteria.matches(&tool) {
results.push(ToolSearchMatch {
server_name: server_name.clone(),
tool,
});
}
}
}
Err(e) => {
let error_msg = format!("Error connecting to server {}: {}", server_name, e);
if options.continue_on_error {
errors.push(error_msg);
} else {
return Err(e);
}
}
}
}
if !errors.is_empty() && options.continue_on_error {
for error in &errors {
eprintln!("{}", error);
}
}
match options.sort_order {
SortOrder::ServerThenTool => {
results.sort_by(|a, b| {
a.server_name
.cmp(&b.server_name)
.then_with(|| a.tool_name().cmp(b.tool_name()))
});
}
SortOrder::ToolThenServer => {
results.sort_by(|a, b| {
a.tool_name()
.cmp(b.tool_name())
.then_with(|| a.server_name.cmp(&b.server_name))
});
}
SortOrder::None => {
}
}
if let Some(max) = options.max_results {
results.truncate(max);
}
Ok(results)
}
/// Searches for tools whose name, title or description contains `query`
/// (case-insensitive substring match, default options).
pub async fn search_tools_with_query(
    servers: &[ServerConfig],
    query: &str,
) -> Result<Vec<ToolSearchMatch>, ToolSearchError> {
    search_tools(servers, &SearchCriteria::with_query(query.to_owned())).await
}
/// Searches for tools whose searchable fields match the regular expression `pattern`.
pub async fn search_tools_with_regex(
    servers: &[ServerConfig],
    pattern: &str,
) -> Result<Vec<ToolSearchMatch>, ToolSearchError> {
    search_tools(servers, &SearchCriteria::with_regex(pattern.to_owned())).await
}
/// Searches for tools having a field that contains every one of `keywords`.
pub async fn search_tools_with_keywords(
    servers: &[ServerConfig],
    keywords: Vec<String>,
) -> Result<Vec<ToolSearchMatch>, ToolSearchError> {
    search_tools(servers, &SearchCriteria::with_keywords(keywords)).await
}
pub async fn list_all_tools(
servers: &[ServerConfig],
) -> Result<Vec<ToolSearchMatch>, ToolSearchError> {
let criteria = SearchCriteria {
query: None,
name: None,
mode: SearchMode::Substring,
fields: SearchFields::default(),
case_sensitive: false,
min_description_length: None,
keywords: vec![],
regex: None,
};
search_tools(servers, &criteria).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a tool named `test_tool` with a short description and an empty schema.
    fn sample_tool() -> Tool {
        use serde_json::Map;
        use std::sync::Arc;
        Tool {
            name: "test_tool".to_string().into(),
            title: None,
            description: Some("A test tool for testing".to_string().into()),
            input_schema: Arc::new(Map::new()),
            annotations: None,
            icons: None,
            output_schema: None,
        }
    }

    #[test]
    fn test_search_criteria_matches() {
        let tool = sample_tool();

        let expected_hits = [
            SearchCriteria::with_query("test".to_string()),
            SearchCriteria::with_name("test_tool".to_string()),
            SearchCriteria::with_regex(r"test.*tool".to_string()),
            SearchCriteria::with_keywords(vec!["test".to_string(), "tool".to_string()]),
            SearchCriteria::with_query("test".to_string()).with_mode(SearchMode::WordBoundary),
        ];
        for criteria in &expected_hits {
            assert!(criteria.matches(&tool), "expected a match for {:?}", criteria);
        }

        let expected_misses = [
            SearchCriteria::with_query("nonexistent".to_string()),
            SearchCriteria::with_name("other_tool".to_string()),
            SearchCriteria::with_keywords(vec!["test".to_string(), "nonexistent".to_string()]),
        ];
        for criteria in &expected_misses {
            assert!(!criteria.matches(&tool), "expected no match for {:?}", criteria);
        }
    }
}