use async_trait::async_trait;
use sha2::Digest;
use std::collections::HashMap;
use std::io::Write;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tokio::sync::RwLock;
use tokio::time::{timeout, Duration};
use super::types::{
McpClientConfig, McpClientError, McpTool, ToolDiscoveryEvent, ToolProvider,
ToolVerificationRequest, ToolVerificationResponse, VerificationStatus,
};
use crate::integrations::schemapin::{
LocalKeyStore, NativeSchemaPinClient, PinnedKey, SchemaPinClient, VerifyArgs,
};
use crate::integrations::tool_invocation::{
DefaultToolInvocationEnforcer, InvocationContext, InvocationResult, ToolInvocationEnforcer,
ToolInvocationError,
};
/// Registry and gateway for MCP tools.
///
/// Tools are addressed by their `name`. Implementors decide how (and whether) a
/// tool's schema is verified before it may be retrieved or invoked.
#[async_trait]
pub trait McpClient: Send + Sync {
/// Adds a discovered tool to the client and returns an event describing the stored entry.
///
/// # Errors
/// Implementations may reject the tool, e.g. with `McpClientError::VerificationFailed`.
async fn discover_tool(&self, tool: McpTool) -> Result<ToolDiscoveryEvent, McpClientError>;
/// Returns a copy of the registered tool named `name`.
async fn get_tool(&self, name: &str) -> Result<McpTool, McpClientError>;
/// Returns every registered tool, regardless of its verification status.
async fn list_tools(&self) -> Result<Vec<McpTool>, McpClientError>;
/// Returns only the registered tools whose verification status is verified.
async fn list_verified_tools(&self) -> Result<Vec<McpTool>, McpClientError>;
/// Verifies (or re-verifies) the tool described by `request` and reports the outcome.
async fn verify_tool(
&self,
request: ToolVerificationRequest,
) -> Result<ToolVerificationResponse, McpClientError>;
/// Removes the tool named `name`, returning it if it was registered.
async fn remove_tool(&self, name: &str) -> Result<Option<McpTool>, McpClientError>;
/// Invokes the registered tool `tool_name` with `arguments` under the given `context`.
async fn invoke_tool(
&self,
tool_name: &str,
arguments: serde_json::Value,
context: InvocationContext,
) -> Result<InvocationResult, McpClientError>;
}
/// [`McpClient`] that verifies tool schemas with SchemaPin before admitting them.
///
/// Provider public keys are fetched over HTTPS and pinned on first use (TOFU);
/// invocations are routed through a [`ToolInvocationEnforcer`].
pub struct SecureMcpClient {
// Controls verification enforcement, dev-mode allowances and the verification timeout.
config: McpClientConfig,
// Performs the actual schema-signature verification.
schema_pin: Arc<dyn SchemaPinClient>,
// Pinned provider keys, keyed by provider identifier.
key_store: Arc<LocalKeyStore>,
// Registered tools keyed by tool name; shared behind an async RwLock.
tools: Arc<RwLock<HashMap<String, McpTool>>>,
// Policy gate applied to every tool invocation.
enforcer: Arc<dyn ToolInvocationEnforcer>,
// HTTPS-only client used to fetch provider public keys.
http_client: reqwest::Client,
}
impl SecureMcpClient {
    /// Per-request timeout of the internal HTTP client (used for provider key fetches).
    const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates a client whose invocations are gated by [`DefaultToolInvocationEnforcer`].
    ///
    /// # Panics
    ///
    /// Panics if the HTTPS-only HTTP client cannot be built.
    pub fn new(
        config: McpClientConfig,
        schema_pin: Arc<dyn SchemaPinClient>,
        key_store: Arc<LocalKeyStore>,
    ) -> Self {
        Self::with_enforcer(
            config,
            schema_pin,
            key_store,
            Arc::new(DefaultToolInvocationEnforcer::new()),
        )
    }

    /// Creates a client that gates invocations with the caller-supplied `enforcer`.
    ///
    /// # Panics
    ///
    /// Panics if the HTTPS-only HTTP client cannot be built.
    pub fn with_enforcer(
        config: McpClientConfig,
        schema_pin: Arc<dyn SchemaPinClient>,
        key_store: Arc<LocalKeyStore>,
        enforcer: Arc<dyn ToolInvocationEnforcer>,
    ) -> Self {
        Self {
            config,
            schema_pin,
            key_store,
            tools: Arc::new(RwLock::new(HashMap::new())),
            enforcer,
            http_client: Self::build_http_client(),
        }
    }

    /// Builds the HTTP client used to fetch provider public keys.
    ///
    /// `https_only(true)` makes the client refuse non-HTTPS URLs, so a key is never
    /// fetched, and then TOFU-pinned, over an unauthenticated channel.
    fn build_http_client() -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(Self::HTTP_TIMEOUT)
            .https_only(true)
            .build()
            .expect("Failed to build HTTPS-only reqwest client")
    }

    /// Creates a client with [`NativeSchemaPinClient`] and a default [`LocalKeyStore`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `LocalKeyStore::new`.
    pub fn with_defaults(config: McpClientConfig) -> Result<Self, McpClientError> {
        let schema_pin = Arc::new(NativeSchemaPinClient::new());
        let key_store = Arc::new(LocalKeyStore::new()?);
        Ok(Self::new(config, schema_pin, key_store))
    }

    /// Verifies `tool.schema` against its provider's public key.
    ///
    /// The schema is pretty-printed into a temporary file, because `VerifyArgs` is
    /// built from a file path. The provider key is then fetched and pinned if not
    /// already pinned. Finally SchemaPin verification runs, bounded by
    /// `config.verification_timeout_seconds`.
    ///
    /// A verification *failure* is not an error: it is reported as
    /// `VerificationStatus::Failed`.
    ///
    /// # Errors
    ///
    /// * `SerializationError` if the temp file cannot be created, written or
    ///   addressed by a UTF-8 path, or if the schema cannot be serialized.
    /// * Any error from [`Self::fetch_and_pin_key`].
    /// * `Timeout` if verification exceeds the configured timeout.
    async fn verify_schema(&self, tool: &McpTool) -> Result<VerificationStatus, McpClientError> {
        // `temp_file` must stay alive until verification finishes: dropping it deletes the file.
        let mut temp_file =
            NamedTempFile::new().map_err(|e| McpClientError::SerializationError {
                reason: format!("Failed to create temp file: {}", e),
            })?;
        let schema_json = serde_json::to_string_pretty(&tool.schema).map_err(|e| {
            McpClientError::SerializationError {
                reason: format!("Failed to serialize schema: {}", e),
            }
        })?;
        temp_file.write_all(schema_json.as_bytes()).map_err(|e| {
            McpClientError::SerializationError {
                reason: format!("Failed to write schema to temp file: {}", e),
            }
        })?;

        // A lossy conversion would silently point the verifier at a different,
        // nonexistent path; report the problem explicitly instead.
        let temp_path = match temp_file.path().to_str() {
            Some(path) => path.to_owned(),
            None => {
                return Err(McpClientError::SerializationError {
                    reason: format!(
                        "Temp file path is not valid UTF-8: {}",
                        temp_file.path().display()
                    ),
                });
            }
        };

        self.fetch_and_pin_key(&tool.provider).await?;

        let verify_args = VerifyArgs::new(temp_path, tool.provider.public_key_url.clone());
        let verification_timeout = Duration::from_secs(self.config.verification_timeout_seconds);
        let verification_result = timeout(
            verification_timeout,
            self.schema_pin.verify_schema(verify_args),
        )
        .await
        .map_err(|_| McpClientError::Timeout)?;

        match verification_result {
            Ok(result) => Ok(VerificationStatus::Verified {
                result: Box::new(result),
                verified_at: chrono::Utc::now().to_rfc3339(),
            }),
            Err(e) => Ok(VerificationStatus::Failed {
                reason: e.to_string(),
                failed_at: chrono::Utc::now().to_rfc3339(),
            }),
        }
    }

    /// Trust-on-first-use: fetches and pins `provider`'s public key unless one is already pinned.
    ///
    /// The key text is fetched from `provider.public_key_url`. Its fingerprint is the
    /// hex-encoded SHA-256 of that text.
    ///
    /// NOTE(review): the key is recorded with algorithm "ES256" unconditionally;
    /// this presumes all providers publish ES256 keys. Confirm.
    ///
    /// # Errors
    ///
    /// * Errors from the key store (`has_key` / `pin_key`).
    /// * `KeyFetchFailed` if the request fails, returns a non-success status, the
    ///   body cannot be read, or the body is blank.
    async fn fetch_and_pin_key(&self, provider: &ToolProvider) -> Result<(), McpClientError> {
        if self.key_store.has_key(&provider.identifier)? {
            tracing::debug!(
                provider = %provider.identifier,
                "Key already pinned, skipping fetch"
            );
            return Ok(());
        }

        tracing::info!(
            provider = %provider.identifier,
            url = %provider.public_key_url,
            "Fetching provider public key for TOFU pinning"
        );

        let response = self
            .http_client
            .get(&provider.public_key_url)
            .send()
            .await
            .map_err(|e| McpClientError::KeyFetchFailed {
                provider: provider.identifier.clone(),
                reason: format!("HTTP request failed: {}", e),
            })?;

        let status = response.status();
        if !status.is_success() {
            return Err(McpClientError::KeyFetchFailed {
                provider: provider.identifier.clone(),
                reason: format!(
                    "Server returned HTTP {} from {}",
                    status, provider.public_key_url
                ),
            });
        }

        let key_data = response
            .text()
            .await
            .map_err(|e| McpClientError::KeyFetchFailed {
                provider: provider.identifier.clone(),
                reason: format!("Failed to read response body: {}", e),
            })?;
        if key_data.trim().is_empty() {
            return Err(McpClientError::KeyFetchFailed {
                provider: provider.identifier.clone(),
                reason: "Server returned an empty key".to_string(),
            });
        }

        let mut hasher = sha2::Sha256::new();
        hasher.update(key_data.as_bytes());
        let fingerprint = hex::encode(hasher.finalize());

        let pinned_key = PinnedKey::new(
            provider.identifier.clone(),
            key_data,
            "ES256".to_string(),
            fingerprint.clone(),
        );
        self.key_store.pin_key(pinned_key)?;

        tracing::info!(
            provider = %provider.identifier,
            url = %provider.public_key_url,
            fingerprint = %fingerprint,
            "Provider public key pinned successfully (TOFU)"
        );
        Ok(())
    }

    /// Decides whether `tool` may be admitted given its verification status and the config.
    ///
    /// The rules are:
    /// * Verified tools are always allowed.
    /// * Failed tools are never allowed.
    /// * Pending tools are allowed only when enforcement is off.
    /// * Skipped tools are allowed only when `allow_unverified_in_dev` is set.
    fn should_allow_tool(&self, tool: &McpTool) -> bool {
        match &tool.verification_status {
            VerificationStatus::Verified { .. } => true,
            VerificationStatus::Failed { .. } => false,
            VerificationStatus::Pending => !self.config.enforce_verification,
            VerificationStatus::Skipped { .. } => self.config.allow_unverified_in_dev,
        }
    }
}
/// Returns `true` when `a` and `b` share every input that schema verification uses.
///
/// Those inputs are the schema itself plus the provider identifier (key-store key)
/// and public-key URL. `verify_tool` relies on this so a result computed for one
/// definition is never recorded against a different definition that merely shares
/// its name.
fn same_verification_inputs(a: &McpTool, b: &McpTool) -> bool {
    a.schema == b.schema
        && a.provider.identifier == b.provider.identifier
        && a.provider.public_key_url == b.provider.public_key_url
}

#[async_trait]
impl McpClient for SecureMcpClient {
    /// Verifies `tool` (when enforcement is on) and registers it under its name.
    ///
    /// With `enforce_verification` off, the tool is stored as `Skipped`. An existing
    /// entry with the same name is replaced.
    ///
    /// # Errors
    ///
    /// `VerificationFailed` when enforcement is on and the tool is not allowed;
    /// otherwise any error from schema verification.
    async fn discover_tool(&self, mut tool: McpTool) -> Result<ToolDiscoveryEvent, McpClientError> {
        tool.verification_status = if self.config.enforce_verification {
            self.verify_schema(&tool).await?
        } else {
            VerificationStatus::Skipped {
                reason: "Verification disabled in configuration".to_string(),
            }
        };

        if self.config.enforce_verification && !self.should_allow_tool(&tool) {
            return Err(McpClientError::VerificationFailed {
                reason: format!(
                    "Tool '{}' failed verification and cannot be added",
                    tool.name
                ),
            });
        }

        {
            let mut tools = self.tools.write().await;
            tools.insert(tool.name.clone(), tool.clone());
        }

        Ok(ToolDiscoveryEvent {
            tool,
            source: "discovery".to_string(),
            discovered_at: chrono::Utc::now().to_rfc3339(),
        })
    }

    /// Returns a copy of the tool named `name`.
    ///
    /// # Errors
    ///
    /// * `ToolNotFound` if no tool has that name.
    /// * `ToolNotVerified` if enforcement is on and the tool is not verified.
    async fn get_tool(&self, name: &str) -> Result<McpTool, McpClientError> {
        let tools = self.tools.read().await;
        let tool = tools
            .get(name)
            .ok_or_else(|| McpClientError::ToolNotFound {
                name: name.to_string(),
            })?;
        if self.config.enforce_verification && !tool.verification_status.is_verified() {
            return Err(McpClientError::ToolNotVerified {
                name: name.to_string(),
            });
        }
        Ok(tool.clone())
    }

    /// Returns every registered tool, verified or not.
    async fn list_tools(&self) -> Result<Vec<McpTool>, McpClientError> {
        let tools = self.tools.read().await;
        Ok(tools.values().cloned().collect())
    }

    /// Returns only the registered tools whose status is verified.
    async fn list_verified_tools(&self) -> Result<Vec<McpTool>, McpClientError> {
        let tools = self.tools.read().await;
        Ok(tools
            .values()
            .filter(|tool| tool.verification_status.is_verified())
            .cloned()
            .collect())
    }

    /// Verifies `request.tool` and, if a tool with the same name *and the same
    /// verification inputs* is registered, records the result on it.
    ///
    /// Without `force_reverify`, an already-verified registered tool with the same
    /// definition is returned as-is, with a warning. If the registered tool's schema
    /// or provider differ from the request, its status is left untouched and a
    /// warning is added. Otherwise a signed schema could be used to mark an
    /// unrelated definition as verified.
    ///
    /// # Errors
    ///
    /// Any error from schema verification.
    async fn verify_tool(
        &self,
        request: ToolVerificationRequest,
    ) -> Result<ToolVerificationResponse, McpClientError> {
        if !request.force_reverify {
            let tools = self.tools.read().await;
            if let Some(existing_tool) = tools.get(&request.tool.name) {
                if existing_tool.verification_status.is_verified()
                    && same_verification_inputs(existing_tool, &request.tool)
                {
                    return Ok(ToolVerificationResponse {
                        tool_name: request.tool.name,
                        status: existing_tool.verification_status.clone(),
                        warnings: vec![
                            "Tool already verified, use force_reverify to re-verify".to_string(),
                        ],
                    });
                }
            }
        }

        let verification_status = self.verify_schema(&request.tool).await?;

        let mut warnings = Vec::new();
        {
            let mut tools = self.tools.write().await;
            if let Some(existing_tool) = tools.get_mut(&request.tool.name) {
                // Re-checked under the write lock, so a definition swapped in while
                // verification was running cannot inherit this result.
                if same_verification_inputs(existing_tool, &request.tool) {
                    existing_tool.verification_status = verification_status.clone();
                } else {
                    warnings.push(
                        "Registered tool differs from the verified definition; its status was not updated"
                            .to_string(),
                    );
                }
            }
        }

        Ok(ToolVerificationResponse {
            tool_name: request.tool.name,
            status: verification_status,
            warnings,
        })
    }

    /// Unregisters the tool named `name`, returning it if it existed.
    async fn remove_tool(&self, name: &str) -> Result<Option<McpTool>, McpClientError> {
        let mut tools = self.tools.write().await;
        Ok(tools.remove(name))
    }

    /// Runs `tool_name` through the invocation enforcer.
    ///
    /// The tool is looked up with [`McpClient::get_tool`], so the same verification
    /// gate applies. Enforcer errors are mapped onto `McpClientError`.
    ///
    /// NOTE(review): `_arguments` is not forwarded to the enforcer. Confirm whether
    /// `InvocationContext` is expected to carry them.
    async fn invoke_tool(
        &self,
        tool_name: &str,
        _arguments: serde_json::Value,
        context: InvocationContext,
    ) -> Result<InvocationResult, McpClientError> {
        let tool = self.get_tool(tool_name).await?;
        self.enforcer
            .execute_tool_with_enforcement(&tool, context)
            .await
            .map_err(|e| match e {
                ToolInvocationError::InvocationBlocked {
                    tool_name,
                    reason: _,
                } => McpClientError::ToolNotVerified { name: tool_name },
                ToolInvocationError::ToolNotFound { tool_name } => {
                    McpClientError::ToolNotFound { name: tool_name }
                }
                ToolInvocationError::VerificationRequired { tool_name, .. } => {
                    McpClientError::ToolNotVerified { name: tool_name }
                }
                ToolInvocationError::VerificationFailed {
                    tool_name: _,
                    reason,
                } => McpClientError::VerificationFailed { reason },
                _ => McpClientError::CommunicationError {
                    reason: e.to_string(),
                },
            })
    }
}
/// Test double for [`McpClient`] whose verification outcome is fixed at construction.
pub struct MockMcpClient {
// Registered tools keyed by tool name.
tools: Arc<RwLock<HashMap<String, McpTool>>>,
// When false, discovery is rejected and verification reports `Failed`.
should_verify_successfully: bool,
}
impl MockMcpClient {
    /// Mock whose discovery and verification always succeed.
    pub fn new_success() -> Self {
        Self::with_outcome(true)
    }

    /// Mock whose discovery is rejected and whose verification always reports failure.
    pub fn new_failure() -> Self {
        Self::with_outcome(false)
    }

    /// Shared constructor: an empty tool registry plus the fixed verification outcome.
    fn with_outcome(should_verify_successfully: bool) -> Self {
        let tools = Arc::new(RwLock::new(HashMap::new()));
        Self {
            tools,
            should_verify_successfully,
        }
    }
}
#[async_trait]
impl McpClient for MockMcpClient {
async fn discover_tool(&self, mut tool: McpTool) -> Result<ToolDiscoveryEvent, McpClientError> {
tool.verification_status = if self.should_verify_successfully {
VerificationStatus::Verified {
result: Box::new(crate::integrations::schemapin::VerificationResult {
success: true,
message: "Mock verification successful".to_string(),
schema_hash: Some("mock_hash".to_string()),
public_key_url: Some(tool.provider.public_key_url.clone()),
signature: None,
metadata: None,
timestamp: Some(chrono::Utc::now().to_rfc3339()),
}),
verified_at: chrono::Utc::now().to_rfc3339(),
}
} else {
VerificationStatus::Failed {
reason: "Mock verification failed".to_string(),
failed_at: chrono::Utc::now().to_rfc3339(),
}
};
if !self.should_verify_successfully {
return Err(McpClientError::VerificationFailed {
reason: "Mock verification failed".to_string(),
});
}
let mut tools = self.tools.write().await;
tools.insert(tool.name.clone(), tool.clone());
Ok(ToolDiscoveryEvent {
tool,
source: "mock".to_string(),
discovered_at: chrono::Utc::now().to_rfc3339(),
})
}
async fn get_tool(&self, name: &str) -> Result<McpTool, McpClientError> {
let tools = self.tools.read().await;
tools
.get(name)
.cloned()
.ok_or_else(|| McpClientError::ToolNotFound {
name: name.to_string(),
})
}
async fn list_tools(&self) -> Result<Vec<McpTool>, McpClientError> {
let tools = self.tools.read().await;
Ok(tools.values().cloned().collect())
}
async fn list_verified_tools(&self) -> Result<Vec<McpTool>, McpClientError> {
let tools = self.tools.read().await;
Ok(tools
.values()
.filter(|tool| tool.verification_status.is_verified())
.cloned()
.collect())
}
async fn verify_tool(
&self,
request: ToolVerificationRequest,
) -> Result<ToolVerificationResponse, McpClientError> {
let status = if self.should_verify_successfully {
VerificationStatus::Verified {
result: Box::new(crate::integrations::schemapin::VerificationResult {
success: true,
message: "Mock verification successful".to_string(),
schema_hash: Some("mock_hash".to_string()),
public_key_url: Some(request.tool.provider.public_key_url.clone()),
signature: None,
metadata: None,
timestamp: Some(chrono::Utc::now().to_rfc3339()),
}),
verified_at: chrono::Utc::now().to_rfc3339(),
}
} else {
VerificationStatus::Failed {
reason: "Mock verification failed".to_string(),
failed_at: chrono::Utc::now().to_rfc3339(),
}
};
Ok(ToolVerificationResponse {
tool_name: request.tool.name,
status,
warnings: vec![],
})
}
async fn remove_tool(&self, name: &str) -> Result<Option<McpTool>, McpClientError> {
let mut tools = self.tools.write().await;
Ok(tools.remove(name))
}
async fn invoke_tool(
&self,
tool_name: &str,
arguments: serde_json::Value,
_context: InvocationContext,
) -> Result<InvocationResult, McpClientError> {
let tool = self.get_tool(tool_name).await?;
if !self.should_verify_successfully && !tool.verification_status.is_verified() {
return Err(McpClientError::ToolNotVerified {
name: tool_name.to_string(),
});
}
Ok(InvocationResult {
success: true,
result: serde_json::json!({
"status": "success",
"tool": tool_name,
"arguments": arguments
}),
execution_time: Duration::from_millis(50),
warnings: vec![],
metadata: std::collections::HashMap::new(),
})
}
}
// Unit tests for the mock and secure MCP clients. Secure-client tests pre-pin the
// provider key so no network fetch ever happens.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::integrations::schemapin::types::KeyStoreConfig;
    use crate::integrations::schemapin::MockNativeSchemaPinClient;

    /// Minimal tool from provider `example.com`, initially `Pending`.
    fn create_test_tool() -> McpTool {
        let provider = ToolProvider {
            identifier: "example.com".to_string(),
            name: "Example Provider".to_string(),
            public_key_url: "https://example.com/pubkey".to_string(),
            version: Some("1.0.0".to_string()),
        };
        McpTool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "input": {"type": "string"}
                }
            }),
            provider,
            verification_status: VerificationStatus::Pending,
            metadata: None,
            sensitive_params: vec![],
        }
    }

    /// Empty key store in a fresh temp dir; the returned `TempDir` must outlive the store.
    fn create_temp_key_store() -> (LocalKeyStore, tempfile::TempDir) {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let config = KeyStoreConfig {
            store_path: temp_dir.path().join("test_keys.json"),
            create_if_missing: true,
            file_permissions: Some(0o600),
        };
        let store = LocalKeyStore::with_config(config).unwrap();
        (store, temp_dir)
    }

    /// Key store with a key already pinned for `example.com`, wrapped for client use.
    fn create_pinned_key_store() -> (Arc<LocalKeyStore>, tempfile::TempDir) {
        let (store, temp_dir) = create_temp_key_store();
        let key = PinnedKey::new(
            "example.com".to_string(),
            "test_key".to_string(),
            "ES256".to_string(),
            "test_fingerprint".to_string(),
        );
        store.pin_key(key).unwrap();
        (Arc::new(store), temp_dir)
    }

    #[tokio::test]
    async fn test_mock_client_success() {
        let client = MockMcpClient::new_success();
        let original = create_test_tool();

        let event = client.discover_tool(original.clone()).await.unwrap();
        assert!(event.tool.verification_status.is_verified());

        let fetched = client.get_tool(&original.name).await.unwrap();
        assert_eq!(fetched.name, original.name);
    }

    #[tokio::test]
    async fn test_mock_client_failure() {
        let client = MockMcpClient::new_failure();

        let outcome = client.discover_tool(create_test_tool()).await;

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(McpClientError::VerificationFailed { .. })
        ));
    }

    #[tokio::test]
    async fn test_secure_client_with_mock_components() {
        let (key_store, _temp_dir) = create_pinned_key_store();
        let client = SecureMcpClient::new(
            McpClientConfig::default(),
            Arc::new(MockNativeSchemaPinClient::new_success()),
            key_store,
        );
        let original = create_test_tool();

        let event = client.discover_tool(original.clone()).await.unwrap();
        assert!(event.tool.verification_status.is_verified());

        let verified = client.list_verified_tools().await.unwrap();
        assert_eq!(verified.len(), 1);
        assert_eq!(verified[0].name, original.name);
    }

    #[tokio::test]
    async fn test_verification_enforcement() {
        let (key_store, _temp_dir) = create_pinned_key_store();
        let config = McpClientConfig {
            enforce_verification: true,
            ..Default::default()
        };
        let client = SecureMcpClient::new(
            config,
            Arc::new(MockNativeSchemaPinClient::new_failure()),
            key_store,
        );

        let outcome = client.discover_tool(create_test_tool()).await;

        assert!(outcome.is_err());
        assert!(matches!(
            outcome,
            Err(McpClientError::VerificationFailed { .. })
        ));
    }
}