use crate::config::mcp::McpClientConfig;
pub mod cli;
mod client;
pub mod connection_pool;
pub mod enhanced_config;
pub mod errors;
mod provider;
mod rmcp_client;
pub mod rmcp_transport;
pub mod schema;
pub mod tool_discovery;
pub mod tool_discovery_cache;
pub mod traits;
pub mod types;
pub mod utils;
pub use client::McpClient;
pub use connection_pool::{
ConnectionPoolStats, McpConnectionPool, McpPoolError, PooledMcpManager, PooledMcpStats,
};
pub use errors::{
ErrorCode, McpResult, configuration_error, initialization_timeout, provider_not_found,
provider_unavailable, schema_invalid, tool_invocation_failed, tool_not_found,
};
pub use provider::McpProvider;
pub(crate) use rmcp_client::RmcpClient;
pub use rmcp_transport::{
HttpTransport, create_http_transport, create_stdio_transport,
create_stdio_transport_with_stderr,
};
pub use schema::{validate_against_schema, validate_tool_input};
pub use tool_discovery::{DetailLevel, ToolDiscovery, ToolDiscoveryResult};
pub use traits::{McpElicitationHandler, McpToolExecutor};
pub use types::{
FileParamSchemaEntry, FileUploadResult, McpClientStatus, McpElicitationRequest,
McpElicitationResponse, McpPromptDetail, McpPromptInfo, McpResourceData, McpResourceInfo,
McpToolInfo, OPENAI_FILE_PARAMS_META_KEY, OPENAI_FILE_PARAMS_VALUE, ProvidedFilePayload,
};
pub use utils::{
LOCAL_TIMEZONE_ENV_VAR, TIMEZONE_ARGUMENT, TZ_ENV_VAR, build_headers, detect_local_timezone,
ensure_timezone_argument, schema_requires_field,
};
use anyhow::{Result, anyhow};
use hashbrown::HashMap;
pub use rmcp::model::ElicitationAction;
use std::ffi::OsString;
/// MCP protocol revision that this module treats as the latest one.
// NOTE(review): newer MCP specification revisions exist; confirm whether this
// value (and `SUPPORTED_PROTOCOL_VERSIONS`) should be bumped.
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
/// Every MCP protocol revision listed as supported; currently only
/// [`LATEST_PROTOCOL_VERSION`].
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[LATEST_PROTOCOL_VERSION];
/// Converts `value` into another serde-compatible type by round-tripping it
/// through an intermediate `serde_json::Value`.
///
/// Structurally compatible types (per its name, this is meant for producing
/// `rmcp` model types) can be converted this way without hand-written field
/// mappings.
///
/// # Errors
///
/// Fails if `value` cannot be serialized to JSON, or if the resulting JSON
/// does not deserialize into `U`.
pub(crate) fn convert_to_rmcp<T, U>(value: T) -> Result<U>
where
    T: serde::Serialize,
    U: serde::de::DeserializeOwned,
{
    let intermediate = serde_json::to_value(value)?;
    match serde_json::from_value(intermediate) {
        Ok(converted) => Ok(converted),
        Err(err) => Err(anyhow!(err)),
    }
}
fn create_env_for_mcp_server(
extra_env: Option<HashMap<OsString, OsString>>,
) -> HashMap<OsString, OsString> {
DEFAULT_ENV_VARS
.iter()
.filter_map(|var| std::env::var_os(var).map(|value| (OsString::from(*var), value)))
.chain(extra_env.unwrap_or_default())
.collect()
}
/// Validates an MCP client configuration, stopping at the first problem found.
///
/// Rules checked, in order:
/// - If `server.enabled`: `server.port` must be non-zero, `server.bind_address`
///   must not be empty or whitespace-only, and when `security.auth_enabled` is
///   set, `security.api_key_env` must be `Some`.
/// - `startup_timeout_seconds`, when set, must not exceed 300.
/// - `tool_timeout_seconds`, when set, must not exceed 3600.
/// - Every provider must have a name that is not empty or whitespace-only, and
///   a `max_concurrent_requests` greater than 0.
///
/// # Errors
///
/// Returns an error whose message describes the first violated rule.
pub fn validate_mcp_config(config: &McpClientConfig) -> Result<()> {
    if config.server.enabled {
        if config.server.port == 0 {
            anyhow::bail!("Invalid server port: {}", config.server.port);
        }
        // A whitespace-only address is just as unusable as an empty one.
        if config.server.bind_address.trim().is_empty() {
            anyhow::bail!("Server bind address cannot be empty");
        }
        if config.security.auth_enabled && config.security.api_key_env.is_none() {
            anyhow::bail!("API key environment variable must be set when auth is enabled");
        }
    }

    if let Some(startup_timeout) = config.startup_timeout_seconds
        && startup_timeout > 300
    {
        anyhow::bail!("Startup timeout cannot exceed 300 seconds");
    }

    if let Some(tool_timeout) = config.tool_timeout_seconds
        && tool_timeout > 3600
    {
        anyhow::bail!("Tool timeout cannot exceed 3600 seconds");
    }

    for provider in &config.providers {
        // A whitespace-only name is as meaningless as an empty one.
        if provider.name.trim().is_empty() {
            anyhow::bail!("MCP provider name cannot be empty");
        }
        if provider.max_concurrent_requests == 0 {
            anyhow::bail!(
                "Max concurrent requests must be greater than 0 for provider '{}'",
                provider.name
            );
        }
    }

    Ok(())
}
/// Environment variables that `create_env_for_mcp_server` copies from the
/// current process on Unix-like targets. Variables that are not set are
/// skipped, so platform-specific entries are harmless where absent.
// NOTE(review): only `unix` and `windows` define this constant, so a target
// that is neither would fail to compile `create_env_for_mcp_server`.
#[cfg(unix)]
const DEFAULT_ENV_VARS: &[&str] = &[
    "HOME",
    "LOGNAME",
    "PATH",
    "SHELL",
    "USER",
    // Presumably macOS-specific (CoreFoundation text encoding) — confirm.
    "__CF_USER_TEXT_ENCODING",
    "LANG",
    "LC_ALL",
    "TERM",
    "TMPDIR",
    "TZ",
];
/// Environment variables that `create_env_for_mcp_server` copies from the
/// current process on Windows targets. Variables that are not set are skipped.
#[cfg(windows)]
const DEFAULT_ENV_VARS: &[&str] = &[
    "PATH",
    "PATHEXT",
    "COMSPEC",
    "SYSTEMROOT",
    "SYSTEMDRIVE",
    "USERNAME",
    "USERDOMAIN",
    "USERPROFILE",
    "HOMEDRIVE",
    "HOMEPATH",
    "PROGRAMFILES",
    "PROGRAMFILES(X86)",
    "PROGRAMW6432",
    "PROGRAMDATA",
    "LOCALAPPDATA",
    "APPDATA",
    "TEMP",
    "TMP",
    "POWERSHELL",
    "PWSH",
];
/// Replaces every character of `name` that is not alphanumeric (Unicode-aware),
/// `_`, or `-` with `_`.
///
/// Path separators and dots are therefore always replaced. The output has the
/// same number of characters as the input; an empty input yields an empty
/// string.
fn sanitize_filename(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = ch.is_alphanumeric() || matches!(ch, '_' | '-');
        sanitized.push(if keep { ch } else { '_' });
    }
    sanitized
}
/// Renders an MCP tool's metadata as a standalone Markdown document.
///
/// Sections, in order:
/// - a title (the tool name);
/// - the provider;
/// - the description;
/// - the input schema as a pretty-printed JSON code block;
/// - a "Required Parameters" list, only when the schema's `required` array is non-empty;
/// - a "Parameters" section, only when its `properties` object is non-empty;
/// - a fixed footer line.
fn format_tool_markdown(tool: &McpToolInfo) -> String {
    // `fmt::Write` lets us format straight into `content` instead of allocating
    // a temporary `String` per `format!`. Writing into a `String` cannot fail,
    // so the returned `fmt::Result`s are deliberately ignored.
    use std::fmt::Write as _;

    let mut content = String::new();
    let _ = writeln!(content, "# {}\n", tool.name);
    let _ = writeln!(content, "**Provider**: {}\n", tool.provider);
    content.push_str("## Description\n\n");
    content.push_str(&tool.description);
    content.push_str("\n\n");
    content.push_str("## Input Schema\n\n");
    content.push_str("```json\n");
    // Fall back to compact JSON if pretty-printing fails.
    let schema_json = serde_json::to_string_pretty(&tool.input_schema)
        .unwrap_or_else(|_| tool.input_schema.to_string());
    content.push_str(&schema_json);
    content.push_str("\n```\n\n");
    if let Some(obj) = tool.input_schema.as_object() {
        if let Some(required) = obj.get("required").and_then(|v| v.as_array())
            && !required.is_empty()
        {
            content.push_str("## Required Parameters\n\n");
            // Non-string entries in `required` are skipped.
            for name in required.iter().filter_map(|req| req.as_str()) {
                let _ = writeln!(content, "- `{name}`");
            }
            content.push('\n');
        }
        if let Some(props) = obj.get("properties").and_then(|v| v.as_object())
            && !props.is_empty()
        {
            content.push_str("## Parameters\n\n");
            for (param_name, param_schema) in props {
                let _ = writeln!(content, "### `{param_name}`\n");
                let _ = writeln!(content, "- **Type**: {}", schema_type_label(param_schema));
                // Missing or empty descriptions are omitted entirely.
                if let Some(desc) = param_schema
                    .get("description")
                    .and_then(|d| d.as_str())
                    .filter(|d| !d.is_empty())
                {
                    let _ = writeln!(content, "- **Description**: {desc}");
                }
                content.push('\n');
            }
        }
    }
    content.push_str("---\n");
    content.push_str("*Generated automatically for dynamic context discovery.*\n");
    content
}

/// Returns a human-readable label for a JSON Schema's `type` keyword.
///
/// A string type (e.g. `"string"`) is returned unchanged. An array of types
/// (e.g. `["string", "null"]`) is joined as `string | null`. A missing value,
/// an array with no string entries, or any other value yields `any`.
fn schema_type_label(schema: &serde_json::Value) -> String {
    match schema.get("type") {
        Some(serde_json::Value::String(ty)) => ty.clone(),
        Some(serde_json::Value::Array(types)) => {
            let names: Vec<&str> = types.iter().filter_map(serde_json::Value::as_str).collect();
            if names.is_empty() {
                "any".to_owned()
            } else {
                names.join(" | ")
            }
        }
        _ => "any".to_owned(),
    }
}
// Unit tests for this module's helpers and re-exported utilities.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::mcp::{McpProviderConfig, McpStdioServerConfig, McpTransportConfig};
    use crate::mcp::rmcp_client::{
        build_elicitation_validator, directory_to_file_uri, validate_elicitation_payload,
    };
    use crate::mcp::utils::{clear_test_env_override, set_test_env_override};
    use hashbrown::HashMap;
    use serde_json::{Map, Value, json};
    use std::ffi::OsString;
    use rmcp::model::{
        ClientCapabilities, Implementation, InitializeRequestParams, RootsCapabilities,
    };
    #[cfg(unix)]
    use serial_test::serial;
    #[cfg(unix)]
    use std::os::unix::ffi::OsStringExt;

    /// Scoped test-only environment override. It installs `value` for `key`
    /// via `set_test_env_override` on creation, and clears it via
    /// `clear_test_env_override` when dropped.
    struct EnvGuard {
        key: &'static str,
    }

    impl EnvGuard {
        fn set(key: &'static str, value: &str) -> Self {
            set_test_env_override(key, Some(value));
            Self { key }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            clear_test_env_override(self.key);
        }
    }

    #[test]
    fn schema_detection_handles_required_entries() {
        let schema = json!({
            "type": "object",
            "required": [TIMEZONE_ARGUMENT],
            "properties": {
                TIMEZONE_ARGUMENT: { "type": "string" }
            }
        });
        assert!(schema_requires_field(&schema, TIMEZONE_ARGUMENT));
        assert!(!schema_requires_field(&schema, "location"));
    }

    #[test]
    fn ensure_timezone_injects_from_override_env() {
        // The guard keeps the override active for the whole test and clears it
        // afterwards, so other tests do not observe it.
        let _guard = EnvGuard::set(LOCAL_TIMEZONE_ENV_VAR, "Etc/UTC");
        let mut arguments = Map::new();
        ensure_timezone_argument(&mut arguments, true).unwrap();
        assert_eq!(
            arguments.get(TIMEZONE_ARGUMENT).and_then(Value::as_str),
            Some("Etc/UTC")
        );
    }

    #[test]
    fn ensure_timezone_does_not_override_existing_value() {
        // An explicitly supplied timezone must be left untouched.
        let mut arguments = Map::new();
        arguments.insert(
            TIMEZONE_ARGUMENT.to_string(),
            Value::String("America/New_York".to_owned()),
        );
        ensure_timezone_argument(&mut arguments, true).unwrap();
        assert_eq!(
            arguments.get(TIMEZONE_ARGUMENT).and_then(Value::as_str),
            Some("America/New_York")
        );
    }

    #[test]
    fn create_env_merges_configured_values() {
        // Only the extra entries are asserted; the inherited defaults depend on
        // the host environment.
        let mut extra_env = HashMap::new();
        extra_env.insert(OsString::from("A"), OsString::from("1"));
        extra_env.insert(OsString::from("B"), OsString::from("2"));
        let env = create_env_for_mcp_server(Some(extra_env));
        assert_eq!(env.get(&OsString::from("A")), Some(&OsString::from("1")));
        assert_eq!(env.get(&OsString::from("B")), Some(&OsString::from("2")));
    }

    // Mutates the real process `PATH`, so it is `#[serial]` to keep it from
    // running concurrently with other `#[serial]` tests.
    #[test]
    #[cfg(unix)]
    #[serial]
    // `set_var`/`remove_var` are called inside `unsafe` blocks below.
    #[allow(unsafe_code)]
    fn create_env_preserves_non_utf8_path() {
        let original_path = std::env::var_os("PATH");
        // Byte 0xFF never occurs in valid UTF-8, so this value cannot round-trip
        // through `String`; it checks that raw `OsString`s are preserved.
        let non_utf8_path = OsString::from_vec(b"/tmp/alpha:\xFFbeta".to_vec());
        // SAFETY: `#[serial]` keeps other serial tests from touching the
        // environment concurrently. NOTE(review): non-serial tests running in
        // parallel could still read the environment — confirm none depend on
        // `PATH` at this moment.
        unsafe {
            std::env::set_var("PATH", &non_utf8_path);
        }
        let env = create_env_for_mcp_server(None);
        // Restore `PATH` before asserting, so a failing assertion cannot leak
        // the modified value into later tests.
        match original_path {
            Some(value) => {
                // SAFETY: same reasoning as the `set_var` above.
                unsafe {
                    std::env::set_var("PATH", value);
                }
            }
            None => {
                // SAFETY: same reasoning as the `set_var` above.
                unsafe {
                    std::env::remove_var("PATH");
                }
            }
        }
        assert_eq!(env.get(&OsString::from("PATH")), Some(&non_utf8_path));
    }

    // NOTE(review): the body contains no `.await`; a plain `#[test]` would suffice.
    #[tokio::test]
    async fn convert_to_rmcp_round_trip() {
        let mut capabilities = ClientCapabilities::default();
        capabilities.roots = Some(RootsCapabilities {
            list_changed: Some(true),
        });
        let params =
            InitializeRequestParams::new(capabilities, Implementation::new("vtcode", "1.0"))
                .with_protocol_version(rmcp::model::ProtocolVersion::V_2024_11_05);
        // Round-trips through JSON and checks that `client_info` survives intact.
        let converted: InitializeRequestParams = convert_to_rmcp(params.clone()).unwrap();
        assert_eq!(converted.client_info.name, "vtcode");
        assert_eq!(converted.client_info.version, "1.0");
    }

    #[test]
    fn validate_elicitation_payload_rejects_invalid_content() {
        // `name` must be a string; a number must be rejected by the validator.
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let validator =
            build_elicitation_validator("test", &schema).expect("schema should compile");
        let result = validate_elicitation_payload(
            "test",
            Some(&validator),
            &ElicitationAction::Accept,
            Some(&json!({ "name": 42 })),
        );
        assert!(result.is_err());
    }

    #[test]
    fn validate_elicitation_payload_accepts_valid_content() {
        let schema = json!({
            "type": "object",
            "properties": {
                "email": { "type": "string", "format": "email" }
            },
            "required": ["email"]
        });
        let validator =
            build_elicitation_validator("test", &schema).expect("schema should compile");
        let result = validate_elicitation_payload(
            "test",
            Some(&validator),
            &ElicitationAction::Accept,
            Some(&json!({ "email": "user@example.com" })),
        );
        assert!(result.is_ok());
    }

    // A configured `max_concurrent_requests` of 0 is expected to yield a
    // semaphore with exactly one permit after `McpProvider::connect`.
    // NOTE(review): the stdio transport uses `cat`, which may not exist on
    // Windows if `connect` spawns the command — confirm platform coverage.
    #[tokio::test]
    async fn provider_max_concurrency_defaults_to_one() {
        let config = McpProviderConfig {
            name: "test".into(),
            transport: McpTransportConfig::Stdio(McpStdioServerConfig {
                command: "cat".into(),
                args: vec![],
                working_directory: None,
            }),
            env: HashMap::new(),
            enabled: true,
            max_concurrent_requests: 0,
            startup_timeout_ms: None,
        };
        let provider = McpProvider::connect(config, None).await.unwrap();
        assert_eq!(provider.semaphore.available_permits(), 1);
    }

    #[test]
    fn directory_to_file_uri_generates_file_scheme() {
        let temp_dir = std::env::temp_dir();
        let uri = directory_to_file_uri(temp_dir.as_path())
            .expect("should create file uri for temp directory");
        assert!(uri.starts_with("file://"));
    }
}