use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::context::JobContext;
use crate::secrets::SecretsStore;
use crate::tools::mcp::auth::refresh_access_token;
use crate::tools::mcp::config::McpServerConfig;
use crate::tools::mcp::http_transport::HttpMcpTransport;
use crate::tools::mcp::protocol::{
CallToolResult, InitializeResult, ListToolsResult, McpRequest, McpResponse, McpTool,
};
use crate::tools::mcp::session::McpSessionManager;
use crate::tools::mcp::transport::McpTransport;
use crate::tools::tool::{ApprovalRequirement, Tool, ToolError, ToolOutput};
pub struct McpClient {
transport: Arc<dyn McpTransport>,
server_url: String,
server_name: String,
next_id: AtomicU64,
tools_cache: RwLock<Option<Vec<McpTool>>>,
session_manager: Option<Arc<McpSessionManager>>,
secrets: Option<Arc<dyn SecretsStore + Send + Sync>>,
user_id: String,
server_config: Option<McpServerConfig>,
custom_headers: HashMap<String, String>,
initialized: tokio::sync::OnceCell<InitializeResult>,
}
impl McpClient {
pub fn new(server_url: impl Into<String>) -> Self {
let url: String = server_url.into();
let name = extract_server_name(&url);
let transport = Arc::new(HttpMcpTransport::new(url.clone(), name.clone()));
Self {
transport,
server_url: url,
server_name: name,
next_id: AtomicU64::new(1),
tools_cache: RwLock::new(None),
session_manager: None,
secrets: None,
user_id: "default".to_string(),
server_config: None,
custom_headers: HashMap::new(),
initialized: tokio::sync::OnceCell::new(),
}
}
pub fn new_with_name(server_name: impl Into<String>, server_url: impl Into<String>) -> Self {
let name: String = server_name.into();
let url: String = server_url.into();
let transport = Arc::new(HttpMcpTransport::new(url.clone(), name.clone()));
Self {
transport,
server_url: url,
server_name: name,
next_id: AtomicU64::new(1),
tools_cache: RwLock::new(None),
session_manager: None,
secrets: None,
user_id: "default".to_string(),
server_config: None,
custom_headers: HashMap::new(),
initialized: tokio::sync::OnceCell::new(),
}
}
#[cfg(test)]
pub fn new_with_config(config: McpServerConfig) -> Result<Self, ToolError> {
if !matches!(
config.effective_transport(),
crate::tools::mcp::config::EffectiveTransport::Http
) {
return Err(ToolError::InvalidParameters(
"new_with_config only supports HTTP transport; use new_with_transport for stdio/UDS"
.to_string(),
));
}
let transport = Arc::new(HttpMcpTransport::new(
config.url.clone(),
config.name.clone(),
));
Ok(Self {
transport,
server_url: config.url.clone(),
server_name: config.name.clone(),
next_id: AtomicU64::new(1),
tools_cache: RwLock::new(None),
session_manager: None,
secrets: None,
user_id: "default".to_string(),
custom_headers: config.headers.clone(),
initialized: tokio::sync::OnceCell::new(),
server_config: Some(config),
})
}
pub fn new_authenticated(
config: McpServerConfig,
session_manager: Arc<McpSessionManager>,
secrets: Arc<dyn SecretsStore + Send + Sync>,
user_id: impl Into<String>,
) -> Self {
let transport = Arc::new(
HttpMcpTransport::new(config.url.clone(), config.name.clone())
.with_session_manager(session_manager.clone()),
);
let custom_headers = config.headers.clone();
Self {
transport,
server_url: config.url.clone(),
server_name: config.name.clone(),
next_id: AtomicU64::new(1),
tools_cache: RwLock::new(None),
session_manager: Some(session_manager),
secrets: Some(secrets),
user_id: user_id.into(),
server_config: Some(config),
custom_headers,
initialized: tokio::sync::OnceCell::new(),
}
}
pub fn new_with_transport(
server_name: impl Into<String>,
transport: Arc<dyn McpTransport>,
session_manager: Option<Arc<McpSessionManager>>,
secrets: Option<Arc<dyn SecretsStore + Send + Sync>>,
user_id: impl Into<String>,
server_config: Option<McpServerConfig>,
) -> Self {
let name: String = server_name.into();
let url = server_config
.as_ref()
.map(|c| c.url.clone())
.unwrap_or_default();
let custom_headers = server_config
.as_ref()
.map(|c| c.headers.clone())
.unwrap_or_default();
Self {
transport,
server_url: url,
server_name: name,
next_id: AtomicU64::new(1),
tools_cache: RwLock::new(None),
session_manager,
secrets,
user_id: user_id.into(),
server_config,
custom_headers,
initialized: tokio::sync::OnceCell::new(),
}
}
#[cfg(test)]
pub fn with_session_manager(mut self, session_manager: Arc<McpSessionManager>) -> Self {
self.session_manager = Some(session_manager);
self
}
pub fn server_name(&self) -> &str {
&self.server_name
}
pub fn server_url(&self) -> &str {
&self.server_url
}
pub fn has_session_manager(&self) -> bool {
self.session_manager.is_some()
}
#[cfg(test)]
pub(crate) fn transport(&self) -> &Arc<dyn McpTransport> {
&self.transport
}
fn next_request_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::SeqCst)
}
async fn get_access_token(&self) -> Result<Option<String>, ToolError> {
let Some(ref secrets) = self.secrets else {
return Ok(None);
};
let Some(ref config) = self.server_config else {
return Ok(None);
};
match secrets
.get_decrypted(&self.user_id, &config.token_secret_name())
.await
{
Ok(token) => Ok(Some(token.expose().to_string())),
Err(crate::secrets::SecretError::NotFound(_)) => Ok(None),
Err(crate::secrets::SecretError::Expired) => {
tracing::info!(
server = %self.server_name,
"Access token expired, attempting refresh"
);
match refresh_access_token(config, secrets, &self.user_id).await {
Ok(new_token) => {
tracing::info!(
server = %self.server_name,
"Access token refreshed successfully"
);
Ok(Some(new_token.access_token))
}
Err(e) => {
tracing::warn!(
server = %self.server_name,
"Token refresh failed: {}", e
);
Err(ToolError::ExternalService(format!(
"Failed to get access token: Secret has expired \
and refresh failed: {}",
e
)))
}
}
}
Err(e) => Err(ToolError::ExternalService(format!(
"Failed to get access token: {}",
e
))),
}
}
async fn build_request_headers(&self) -> Result<HashMap<String, String>, ToolError> {
let mut headers = self.custom_headers.clone();
let has_custom_auth = self
.custom_headers
.keys()
.any(|k| k.eq_ignore_ascii_case("authorization"));
if !has_custom_auth && let Some(token) = self.get_access_token().await? {
let trimmed = token.trim();
if !trimmed.is_empty() {
headers.insert("Authorization".to_string(), format!("Bearer {}", trimmed));
}
}
if let Some(ref session_manager) = self.session_manager
&& let Some(session_id) = session_manager.get_session_id(&self.server_name).await
{
headers.insert("Mcp-Session-Id".to_string(), session_id);
}
Ok(headers)
}
async fn reinitialize_session(&self) -> Result<InitializeResult, ToolError> {
if let Some(ref session_manager) = self.session_manager {
session_manager.terminate(&self.server_name).await;
session_manager
.get_or_create(&self.server_name, &self.server_url)
.await;
}
let request = McpRequest::initialize(self.next_request_id());
let response = self
.transport
.send(&request, &self.build_request_headers().await?)
.await?;
if let Some(error) = response.error {
return Err(ToolError::ExternalService(format!(
"MCP initialization error: {} (code {})",
error.message, error.code
)));
}
let init_result: InitializeResult = response
.result
.ok_or_else(|| {
ToolError::ExternalService("No result in initialize response".to_string())
})
.and_then(|r| {
serde_json::from_value(r).map_err(|e| {
ToolError::ExternalService(format!("Invalid initialize result: {}", e))
})
})?;
if let Some(ref session_manager) = self.session_manager {
session_manager.mark_initialized(&self.server_name).await;
}
let notification = McpRequest::initialized_notification();
if let Err(e) = self
.transport
.send(¬ification, &self.build_request_headers().await?)
.await
{
tracing::debug!(
"Failed to send initialized notification to '{}': {}",
self.server_name,
e
);
}
Ok(init_result)
}
fn is_session_expiry_error(message: &str) -> bool {
let lower = message.to_ascii_lowercase();
lower.contains("session")
&& (lower.contains("400")
|| lower.contains("missing session id")
|| lower.contains("no valid session id"))
}
async fn send_request(&self, request: McpRequest) -> Result<McpResponse, ToolError> {
if !self.transport.supports_http_features() {
let headers = self.build_request_headers().await?;
return self.transport.send(&request, &headers).await;
}
for attempt in 0..2 {
let headers = self.build_request_headers().await?;
let result = self.transport.send(&request, &headers).await;
match result {
Ok(response) => return Ok(response),
Err(ToolError::ExternalService(ref msg))
if attempt == 0
&& self.session_manager.is_some()
&& Self::is_session_expiry_error(msg) =>
{
tracing::debug!(
"MCP session expired, attempting reinitialize for '{}'",
self.server_name
);
self.reinitialize_session().await?;
continue;
}
Err(ToolError::ExternalService(ref msg))
if msg.contains("401")
|| msg.contains("Unauthorized")
|| (msg.contains("400") && {
let lower = msg.to_ascii_lowercase();
lower.contains("authorization") || lower.contains("authenticate")
}) =>
{
if attempt == 0
&& let Some(ref secrets) = self.secrets
&& let Some(ref config) = self.server_config
{
tracing::debug!(
"MCP token expired, attempting refresh for '{}'",
self.server_name
);
match refresh_access_token(config, secrets, &self.user_id).await {
Ok(_) => {
tracing::info!("MCP token refreshed for '{}'", self.server_name);
continue;
}
Err(e) => {
tracing::debug!(
"Token refresh failed for '{}': {}",
self.server_name,
e
);
}
}
}
return Err(ToolError::ExternalService(format!(
"MCP server '{}' requires authentication. Run: ironclaw mcp auth {}",
self.server_name, self.server_name
)));
}
Err(e) => return Err(e),
}
}
Err(ToolError::ExternalService(
"MCP request failed after retry".to_string(),
))
}
pub async fn initialize(&self) -> Result<InitializeResult, ToolError> {
let result = self
.initialized
.get_or_try_init(|| async {
if let Some(ref session_manager) = self.session_manager
&& session_manager.is_initialized(&self.server_name).await
{
return Ok(InitializeResult::default());
}
self.reinitialize_session().await
})
.await?;
Ok(result.clone())
}
pub async fn list_tools(&self) -> Result<Vec<McpTool>, ToolError> {
if let Some(tools) = self.tools_cache.read().await.as_ref() {
return Ok(tools.clone());
}
self.initialize().await?;
let request = McpRequest::list_tools(self.next_request_id());
let response = self.send_request(request).await?;
if let Some(error) = response.error {
return Err(ToolError::ExternalService(format!(
"MCP error: {} (code {})",
error.message, error.code
)));
}
let result: ListToolsResult = response
.result
.ok_or_else(|| ToolError::ExternalService("No result in MCP response".to_string()))
.and_then(|r| {
serde_json::from_value(r)
.map_err(|e| ToolError::ExternalService(format!("Invalid tools list: {}", e)))
})?;
*self.tools_cache.write().await = Some(result.tools.clone());
Ok(result.tools)
}
pub async fn call_tool(
&self,
name: &str,
arguments: serde_json::Value,
) -> Result<CallToolResult, ToolError> {
self.initialize().await?;
let request = McpRequest::call_tool(self.next_request_id(), name, arguments);
let response = self.send_request(request).await?;
if let Some(error) = response.error {
return Err(ToolError::ExecutionFailed(format!(
"MCP tool error: {} (code {})",
error.message, error.code
)));
}
response
.result
.ok_or_else(|| ToolError::ExternalService("No result in MCP response".to_string()))
.and_then(|r| {
serde_json::from_value(r)
.map_err(|e| ToolError::ExternalService(format!("Invalid tool result: {}", e)))
})
}
pub async fn clear_cache(&self) {
*self.tools_cache.write().await = None;
}
pub async fn create_tools(&self) -> Result<Vec<Arc<dyn Tool>>, ToolError> {
let mcp_tools = self.list_tools().await?;
let client = Arc::new(self.clone());
Ok(mcp_tools
.into_iter()
.map(|t| {
let prefixed_name = format!("{}_{}", self.server_name, t.name);
Arc::new(McpToolWrapper {
tool: t,
prefixed_name,
client: client.clone(),
}) as Arc<dyn Tool>
})
.collect())
}
pub async fn test_connection(&self) -> Result<(), ToolError> {
self.initialize().await?;
self.list_tools().await?;
Ok(())
}
}
impl Clone for McpClient {
fn clone(&self) -> Self {
Self {
transport: self.transport.clone(),
server_url: self.server_url.clone(),
server_name: self.server_name.clone(),
next_id: AtomicU64::new(self.next_id.load(Ordering::SeqCst)),
tools_cache: RwLock::new(None),
session_manager: self.session_manager.clone(),
secrets: self.secrets.clone(),
user_id: self.user_id.clone(),
server_config: self.server_config.clone(),
custom_headers: self.custom_headers.clone(),
initialized: tokio::sync::OnceCell::new(),
}
}
}
fn extract_server_name(url: &str) -> String {
reqwest::Url::parse(url)
.ok()
.and_then(|u| u.host_str().map(|h| h.to_string()))
.unwrap_or_else(|| "unknown".to_string())
.replace('.', "_")
}
struct McpToolWrapper {
tool: McpTool,
prefixed_name: String,
client: Arc<McpClient>,
}
#[async_trait]
impl Tool for McpToolWrapper {
fn name(&self) -> &str {
&self.prefixed_name
}
fn description(&self) -> &str {
&self.tool.description
}
fn parameters_schema(&self) -> serde_json::Value {
self.tool.input_schema.clone()
}
async fn execute(
&self,
params: serde_json::Value,
_ctx: &JobContext,
) -> Result<ToolOutput, ToolError> {
let start = std::time::Instant::now();
let params = strip_top_level_nulls(params);
let result = self.client.call_tool(&self.tool.name, params).await?;
let content: String = result
.content
.iter()
.filter_map(|b| b.as_text())
.collect::<Vec<_>>()
.join("\n");
if result.is_error {
return Err(ToolError::ExecutionFailed(content));
}
Ok(ToolOutput::text(content, start.elapsed()))
}
fn requires_sanitization(&self) -> bool {
true
}
fn requires_approval(&self, _params: &serde_json::Value) -> ApprovalRequirement {
if self.tool.requires_approval() {
ApprovalRequirement::UnlessAutoApproved
} else {
ApprovalRequirement::Never
}
}
}
fn strip_top_level_nulls(value: serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Object(map) => {
let filtered = map.into_iter().filter(|(_, v)| !v.is_null()).collect();
serde_json::Value::Object(filtered)
}
other => other,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mcp_request_list_tools() {
let req = McpRequest::list_tools(1);
assert_eq!(req.method, "tools/list");
assert_eq!(req.id, Some(1));
}
#[test]
fn test_mcp_request_call_tool() {
let req = McpRequest::call_tool(2, "test", serde_json::json!({"key": "value"}));
assert_eq!(req.method, "tools/call");
assert!(req.params.is_some());
}
#[test]
fn test_extract_server_name() {
assert_eq!(
extract_server_name("https://mcp.notion.com/v1"),
"mcp_notion_com"
);
assert_eq!(extract_server_name("http://localhost:8080"), "localhost");
assert_eq!(extract_server_name("invalid"), "unknown");
}
#[test]
fn test_simple_client_creation() {
let client = McpClient::new("http://localhost:8080");
assert_eq!(client.server_url(), "http://localhost:8080");
assert!(client.session_manager.is_none());
assert!(client.secrets.is_none());
}
#[test]
fn test_extract_server_name_with_port() {
assert_eq!(
extract_server_name("http://example.com:3000"),
"example_com"
);
}
#[test]
fn test_extract_server_name_with_path() {
assert_eq!(
extract_server_name("http://api.server.io/v2/mcp"),
"api_server_io"
);
}
#[test]
fn test_extract_server_name_with_query_params() {
assert_eq!(
extract_server_name("http://mcp.example.com/endpoint?token=abc&v=1"),
"mcp_example_com"
);
}
#[test]
fn test_extract_server_name_https() {
assert_eq!(
extract_server_name("https://secure.mcp.dev"),
"secure_mcp_dev"
);
}
#[test]
fn test_extract_server_name_ip_address() {
assert_eq!(
extract_server_name("http://192.168.1.100:9090/mcp"),
"192_168_1_100"
);
}
#[test]
fn test_new_defaults() {
let client = McpClient::new("http://localhost:9999");
assert_eq!(client.server_url(), "http://localhost:9999");
assert_eq!(client.server_name(), "localhost");
assert!(client.session_manager.is_none());
assert!(client.secrets.is_none());
assert_eq!(client.user_id, "default");
}
#[test]
fn test_new_with_name_uses_custom_name() {
let client = McpClient::new_with_name("my-server", "http://localhost:8080");
assert_eq!(client.server_name(), "my-server");
assert_eq!(client.server_url(), "http://localhost:8080");
assert_eq!(client.user_id, "default");
assert!(client.session_manager.is_none());
assert!(client.secrets.is_none());
}
#[test]
fn test_server_name_accessor() {
let client = McpClient::new("https://tools.example.org/mcp");
assert_eq!(client.server_name(), "tools_example_org");
}
#[test]
fn test_server_url_accessor() {
let url = "https://tools.example.org/mcp?v=2";
let client = McpClient::new(url);
assert_eq!(client.server_url(), url);
}
#[test]
fn test_clone_preserves_fields() {
let client = McpClient::new_with_name("cloned-server", "http://localhost:5555");
client.next_request_id();
client.next_request_id();
let cloned = client.clone();
assert_eq!(cloned.server_url(), "http://localhost:5555");
assert_eq!(cloned.server_name(), "cloned-server");
assert_eq!(cloned.user_id, "default");
assert_eq!(cloned.next_id.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_clone_resets_tools_cache() {
let client = McpClient::new("http://localhost:5555");
let cloned = client.clone();
let cache = cloned.tools_cache.read().await;
assert!(cache.is_none());
}
#[test]
fn test_new_with_config_carries_custom_headers() {
let mut headers = HashMap::new();
headers.insert("X-API-Key".to_string(), "secret".to_string());
headers.insert("X-Custom".to_string(), "value".to_string());
let config = McpServerConfig::new("test", "http://localhost:8080").with_headers(headers);
let client = McpClient::new_with_config(config.clone()).expect("HTTP config should work");
assert_eq!(client.server_name(), "test");
assert_eq!(client.server_url(), "http://localhost:8080");
assert_eq!(client.custom_headers.len(), 2);
assert_eq!(client.custom_headers.get("X-API-Key").unwrap(), "secret");
assert!(client.server_config.is_some());
}
#[test]
fn test_new_with_config_no_headers() {
let config = McpServerConfig::new("bare", "http://localhost:9090");
let client = McpClient::new_with_config(config).expect("HTTP config should work");
assert_eq!(client.server_name(), "bare");
assert!(client.custom_headers.is_empty());
assert!(client.secrets.is_none());
assert!(client.session_manager.is_none());
}
#[test]
fn test_with_session_manager() {
let client = McpClient::new("http://localhost:8080");
assert!(!client.has_session_manager());
let session_manager = Arc::new(McpSessionManager::new());
let client = client.with_session_manager(session_manager);
assert!(client.has_session_manager());
}
#[test]
fn test_next_request_id_monotonically_increasing() {
let client = McpClient::new("http://localhost:1234");
assert_eq!(client.next_request_id(), 1);
assert_eq!(client.next_request_id(), 2);
assert_eq!(client.next_request_id(), 3);
}
#[test]
fn test_mcp_tool_requires_approval_destructive() {
use crate::tools::mcp::protocol::{McpTool, McpToolAnnotations};
let tool = McpTool {
name: "delete_all".to_string(),
description: "Deletes everything".to_string(),
input_schema: serde_json::json!({"type": "object"}),
annotations: Some(McpToolAnnotations {
destructive_hint: true,
side_effects_hint: false,
read_only_hint: false,
execution_time_hint: None,
}),
};
assert!(tool.requires_approval());
}
#[test]
fn test_mcp_tool_no_approval_when_not_destructive() {
use crate::tools::mcp::protocol::{McpTool, McpToolAnnotations};
let tool = McpTool {
name: "read_data".to_string(),
description: "Reads data".to_string(),
input_schema: serde_json::json!({"type": "object"}),
annotations: Some(McpToolAnnotations {
destructive_hint: false,
side_effects_hint: true,
read_only_hint: false,
execution_time_hint: None,
}),
};
assert!(!tool.requires_approval());
}
#[test]
fn test_mcp_tool_no_approval_when_no_annotations() {
use crate::tools::mcp::protocol::McpTool;
let tool = McpTool {
name: "simple_tool".to_string(),
description: "A simple tool".to_string(),
input_schema: serde_json::json!({"type": "object"}),
annotations: None,
};
assert!(!tool.requires_approval());
}
struct MockTransport {
supports_http: bool,
responses: std::sync::Mutex<Vec<McpResponse>>,
recorded_headers: std::sync::Mutex<Vec<HashMap<String, String>>>,
}
impl MockTransport {
fn new(supports_http: bool, responses: Vec<McpResponse>) -> Self {
Self {
supports_http,
responses: std::sync::Mutex::new(responses),
recorded_headers: std::sync::Mutex::new(Vec::new()),
}
}
fn recorded_headers(&self) -> Vec<HashMap<String, String>> {
self.recorded_headers.lock().unwrap().clone()
}
}
#[async_trait]
impl McpTransport for MockTransport {
async fn send(
&self,
_request: &McpRequest,
headers: &HashMap<String, String>,
) -> Result<McpResponse, ToolError> {
self.recorded_headers.lock().unwrap().push(headers.clone());
let mut responses = self.responses.lock().unwrap();
if responses.is_empty() {
return Err(ToolError::ExternalService(
"No more mock responses".to_string(),
));
}
Ok(responses.remove(0))
}
async fn shutdown(&self) -> Result<(), ToolError> {
Ok(())
}
fn supports_http_features(&self) -> bool {
self.supports_http
}
}
struct RetryMockTransport {
supports_http: bool,
outcomes: std::sync::Mutex<std::collections::VecDeque<Result<McpResponse, ToolError>>>,
recorded_headers: std::sync::Mutex<Vec<HashMap<String, String>>>,
}
impl RetryMockTransport {
fn new(supports_http: bool, outcomes: Vec<Result<McpResponse, ToolError>>) -> Self {
Self {
supports_http,
outcomes: std::sync::Mutex::new(outcomes.into()),
recorded_headers: std::sync::Mutex::new(Vec::new()),
}
}
fn recorded_headers(&self) -> Vec<HashMap<String, String>> {
self.recorded_headers.lock().unwrap().clone()
}
}
#[async_trait]
impl McpTransport for RetryMockTransport {
async fn send(
&self,
_request: &McpRequest,
headers: &HashMap<String, String>,
) -> Result<McpResponse, ToolError> {
self.recorded_headers.lock().unwrap().push(headers.clone());
let mut outcomes = self.outcomes.lock().unwrap();
if outcomes.is_empty() {
return Err(ToolError::ExternalService(
"No more mock outcomes".to_string(),
));
}
outcomes.pop_front().unwrap()
}
async fn shutdown(&self) -> Result<(), ToolError> {
Ok(())
}
fn supports_http_features(&self) -> bool {
self.supports_http
}
}
#[tokio::test]
async fn test_non_http_transport_skips_401_retry() {
let init_response = McpResponse {
jsonrpc: "2.0".to_string(),
id: Some(1),
result: Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"}
})),
error: None,
};
let notification_ack = McpResponse {
jsonrpc: "2.0".to_string(),
id: None,
result: None,
error: None,
};
let list_response = McpResponse {
jsonrpc: "2.0".to_string(),
id: Some(2),
result: Some(serde_json::json!({"tools": []})),
error: None,
};
let transport = Arc::new(MockTransport::new(
false,
vec![init_response, notification_ack, list_response],
));
let client = McpClient::new_with_transport(
"test-stdio",
transport.clone(),
None,
None,
"default",
None,
);
let result = client.list_tools().await;
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 0);
let headers = transport.recorded_headers();
assert_eq!(headers.len(), 3);
assert!(!headers[0].contains_key("Authorization"));
assert!(!headers[0].contains_key("Mcp-Session-Id"));
}
#[tokio::test]
async fn test_transport_supports_http_features_accessor() {
let http_transport = HttpMcpTransport::new("http://localhost:8080", "test");
assert!(http_transport.supports_http_features());
let mock_non_http = MockTransport::new(false, vec![]);
assert!(!mock_non_http.supports_http_features());
}
#[tokio::test]
async fn test_stdio_client_auto_initializes_without_session_manager() {
let init_response = McpResponse {
jsonrpc: "2.0".to_string(),
id: Some(1),
result: Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"}
})),
error: None,
};
let notification_ack = McpResponse {
jsonrpc: "2.0".to_string(),
id: None,
result: None,
error: None,
};
let transport = Arc::new(MockTransport::new(
false,
vec![init_response, notification_ack],
));
let client = McpClient::new_with_transport(
"test-stdio",
transport.clone(),
None, None,
"default",
None,
);
let result = client.initialize().await;
assert!(result.is_ok());
assert_eq!(transport.recorded_headers().len(), 2);
let result2 = client.initialize().await;
assert!(result2.is_ok());
assert_eq!(transport.recorded_headers().len(), 2); }
#[tokio::test]
async fn test_http_session_error_triggers_reinitialize_and_retry() {
let init_response = McpResponse {
jsonrpc: "2.0".to_string(),
id: Some(1),
result: Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"}
})),
error: None,
};
let notification_ack = McpResponse {
jsonrpc: "2.0".to_string(),
id: None,
result: None,
error: None,
};
let notification_ack2 = notification_ack.clone();
let session_error = Err(ToolError::ExternalService(
"[test] MCP server returned status: 400 - No valid session ID provided".to_string(),
));
let reinit_response = McpResponse {
jsonrpc: "2.0".to_string(),
id: Some(2),
result: Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"serverInfo": {"name": "test", "version": "1.0"}
})),
error: None,
};
let call_response = McpResponse {
jsonrpc: "2.0".to_string(),
id: Some(3),
result: Some(serde_json::json!({
"content": [{"type": "text", "text": "pong"}],
"is_error": false
})),
error: None,
};
let transport = Arc::new(RetryMockTransport::new(
true,
vec![
Ok(init_response),
Ok(notification_ack),
session_error,
Ok(reinit_response),
Ok(notification_ack2),
Ok(call_response),
],
));
let session_manager = Arc::new(McpSessionManager::new());
let client = McpClient::new_with_transport(
"test-http",
transport.clone(),
Some(session_manager),
None,
"default",
None,
);
client.initialize().await.expect("initial handshake");
let result = client
.call_tool("echo", serde_json::json!({"input": "hello"}))
.await
.expect("call should recover after session expiry");
assert!(!result.is_error);
assert_eq!(result.content.len(), 1);
assert_eq!(result.content[0].as_text(), Some("pong"));
let headers = transport.recorded_headers();
assert_eq!(headers.len(), 6);
}
#[test]
fn test_strip_top_level_nulls_removes_null_fields() {
let input = serde_json::json!({
"query": "search term",
"sort": null,
"filter": null,
"page_size": 10
});
let result = strip_top_level_nulls(input);
let obj = result.as_object().unwrap();
assert_eq!(obj.len(), 2);
assert_eq!(obj["query"], "search term");
assert_eq!(obj["page_size"], 10);
assert!(!obj.contains_key("sort"));
assert!(!obj.contains_key("filter"));
}
#[test]
fn test_strip_top_level_nulls_preserves_non_objects() {
let input = serde_json::json!("just a string");
let result = strip_top_level_nulls(input.clone());
assert_eq!(result, input);
}
#[test]
fn test_strip_top_level_nulls_preserves_nested_nulls() {
let input = serde_json::json!({
"outer": { "inner": null },
"top_null": null
});
let result = strip_top_level_nulls(input);
let obj = result.as_object().unwrap();
assert_eq!(obj.len(), 1);
assert!(obj["outer"]["inner"].is_null());
}
#[test]
fn test_new_with_config_rejects_stdio_transport() {
let config = McpServerConfig::new_stdio(
"stdio-server",
"echo",
vec!["hello".to_string()],
HashMap::new(),
);
let result = McpClient::new_with_config(config);
let err = result
.err()
.expect("stdio config must be rejected")
.to_string();
assert!(
err.contains("new_with_config only supports HTTP"),
"error should explain the restriction: {}",
err
);
}
fn make_test_mcp_tool(destructive: bool) -> McpTool {
use crate::tools::mcp::protocol::McpToolAnnotations;
McpTool {
name: "do_thing".to_string(),
description: "Does a thing".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"input": {"type": "string"}
}
}),
annotations: if destructive {
Some(McpToolAnnotations {
destructive_hint: true,
side_effects_hint: false,
read_only_hint: false,
execution_time_hint: None,
})
} else {
None
},
}
}
#[test]
fn test_mcp_tool_wrapper_name_is_prefixed() {
let client = Arc::new(McpClient::new("http://localhost:8080"));
let wrapper = McpToolWrapper {
tool: make_test_mcp_tool(false),
prefixed_name: "mcp__myserver__do_thing".to_string(),
client,
};
assert_eq!(wrapper.name(), "mcp__myserver__do_thing");
}
#[test]
fn test_mcp_tool_wrapper_description() {
let client = Arc::new(McpClient::new("http://localhost:8080"));
let wrapper = McpToolWrapper {
tool: make_test_mcp_tool(false),
prefixed_name: "mcp__s__do_thing".to_string(),
client,
};
assert_eq!(wrapper.description(), "Does a thing");
}
#[test]
fn test_mcp_tool_wrapper_parameters_schema() {
let client = Arc::new(McpClient::new("http://localhost:8080"));
let wrapper = McpToolWrapper {
tool: make_test_mcp_tool(false),
prefixed_name: "mcp__s__do_thing".to_string(),
client,
};
let schema = wrapper.parameters_schema();
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["input"].is_object());
}
#[test]
fn test_mcp_tool_wrapper_requires_sanitization() {
let client = Arc::new(McpClient::new("http://localhost:8080"));
let wrapper = McpToolWrapper {
tool: make_test_mcp_tool(false),
prefixed_name: "mcp__s__do_thing".to_string(),
client,
};
assert!(
wrapper.requires_sanitization(),
"MCP tools should always require sanitization"
);
}
#[test]
fn test_mcp_tool_wrapper_approval_destructive() {
let client = Arc::new(McpClient::new("http://localhost:8080"));
let wrapper = McpToolWrapper {
tool: make_test_mcp_tool(true),
prefixed_name: "mcp__s__do_thing".to_string(),
client,
};
let approval = wrapper.requires_approval(&serde_json::json!({}));
assert_eq!(approval, ApprovalRequirement::UnlessAutoApproved);
}
#[test]
fn test_mcp_tool_wrapper_approval_non_destructive() {
let client = Arc::new(McpClient::new("http://localhost:8080"));
let wrapper = McpToolWrapper {
tool: make_test_mcp_tool(false),
prefixed_name: "mcp__s__do_thing".to_string(),
client,
};
let approval = wrapper.requires_approval(&serde_json::json!({}));
assert_eq!(approval, ApprovalRequirement::Never);
}
#[tokio::test]
async fn test_build_headers_skips_empty_token() {
use crate::secrets::{CreateSecretParams, DecryptedSecret, Secret, SecretError, SecretRef};
use uuid::Uuid;
struct EmptyTokenStore;
#[async_trait]
impl crate::secrets::SecretsStore for EmptyTokenStore {
async fn create(
&self,
_user_id: &str,
_params: CreateSecretParams,
) -> Result<Secret, SecretError> {
unimplemented!()
}
async fn get(&self, _user_id: &str, _name: &str) -> Result<Secret, SecretError> {
unimplemented!()
}
async fn get_decrypted(
&self,
_user_id: &str,
_name: &str,
) -> Result<DecryptedSecret, SecretError> {
DecryptedSecret::from_bytes(b" ".to_vec())
}
async fn exists(&self, _user_id: &str, _name: &str) -> Result<bool, SecretError> {
Ok(true)
}
async fn delete(&self, _user_id: &str, _name: &str) -> Result<bool, SecretError> {
Ok(true)
}
async fn list(&self, _user_id: &str) -> Result<Vec<SecretRef>, SecretError> {
Ok(Vec::new())
}
async fn record_usage(&self, _secret_id: Uuid) -> Result<(), SecretError> {
Ok(())
}
async fn is_accessible(
&self,
_user_id: &str,
_secret_name: &str,
_allowed_secrets: &[String],
) -> Result<bool, SecretError> {
Ok(true)
}
}
let config = McpServerConfig::new("github", "https://api.githubcopilot.com/mcp/");
let session_manager = Arc::new(McpSessionManager::new());
let secrets: Arc<dyn crate::secrets::SecretsStore + Send + Sync> =
Arc::new(EmptyTokenStore);
let client = McpClient::new_authenticated(config, session_manager, secrets, "test-user");
let headers = client.build_request_headers().await.unwrap(); assert!(
!headers.contains_key("Authorization"),
"Empty/whitespace token must not produce an Authorization header, got: {:?}",
headers.get("Authorization")
);
}
#[tokio::test]
async fn test_build_headers_trims_token() {
use crate::secrets::{CreateSecretParams, DecryptedSecret, Secret, SecretError, SecretRef};
use uuid::Uuid;
struct PaddedTokenStore;
#[async_trait]
impl crate::secrets::SecretsStore for PaddedTokenStore {
async fn create(
&self,
_user_id: &str,
_params: CreateSecretParams,
) -> Result<Secret, SecretError> {
unimplemented!()
}
async fn get(&self, _user_id: &str, _name: &str) -> Result<Secret, SecretError> {
unimplemented!()
}
async fn get_decrypted(
&self,
_user_id: &str,
_name: &str,
) -> Result<DecryptedSecret, SecretError> {
DecryptedSecret::from_bytes(b" gho_abc123 \n".to_vec())
}
async fn exists(&self, _user_id: &str, _name: &str) -> Result<bool, SecretError> {
Ok(true)
}
async fn delete(&self, _user_id: &str, _name: &str) -> Result<bool, SecretError> {
Ok(true)
}
async fn list(&self, _user_id: &str) -> Result<Vec<SecretRef>, SecretError> {
Ok(Vec::new())
}
async fn record_usage(&self, _secret_id: Uuid) -> Result<(), SecretError> {
Ok(())
}
async fn is_accessible(
&self,
_user_id: &str,
_secret_name: &str,
_allowed_secrets: &[String],
) -> Result<bool, SecretError> {
Ok(true)
}
}
let config = McpServerConfig::new("github", "https://api.githubcopilot.com/mcp/");
let session_manager = Arc::new(McpSessionManager::new());
let secrets: Arc<dyn crate::secrets::SecretsStore + Send + Sync> =
Arc::new(PaddedTokenStore);
let client = McpClient::new_authenticated(config, session_manager, secrets, "test-user");
let headers = client.build_request_headers().await.unwrap(); assert_eq!(
headers.get("Authorization").unwrap(), "Bearer gho_abc123",
"Token must be trimmed before use in Authorization header"
);
}
}