pub mod security;
use crate::dht::{Key, DHT};
use crate::{PeerId, Result, P2PError};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, Instant};
use tokio::sync::{RwLock, mpsc, oneshot};
use tokio::time::timeout;
use tracing::{debug, info};
use rand;
pub use security::*;
/// MCP protocol revision string defined by this module.
pub const MCP_VERSION: &str = "2024-11-05";
/// Upper bound, in bytes, on a serialised P2P MCP message (1 MiB).
pub const MAX_MESSAGE_SIZE: usize = 1 << 20;
/// Default timeout applied to a single tool call.
pub const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(30);
/// Protocol identifier string for MCP traffic on the P2P network.
pub const MCP_PROTOCOL: &str = "/p2p-foundation/mcp/1.0.0";
/// Pause between background DHT service-discovery sweeps.
pub const SERVICE_DISCOVERY_INTERVAL: Duration = Duration::from_secs(60);
/// An MCP protocol message, carried as the `payload` of a [`P2PMCPMessage`].
///
/// Serialised as an internally tagged object: the variant name, converted to
/// snake_case, is stored under the `"type"` key next to the variant's own fields
/// (e.g. `{"type": "call_tool", "name": ..., "arguments": ...}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MCPMessage {
/// Handshake request: protocol version, capabilities and identity of the client.
Initialize {
protocol_version: String,
capabilities: MCPCapabilities,
client_info: MCPClientInfo,
},
/// Handshake reply: protocol version, capabilities and identity of the server.
InitializeResult {
protocol_version: String,
capabilities: MCPCapabilities,
server_info: MCPServerInfo,
},
/// Request the tool list; `cursor` is an opaque pagination token.
ListTools {
cursor: Option<String>,
},
/// A page of tool definitions plus an optional continuation token
/// (this module's server always replies with `next_cursor: None`).
ListToolsResult {
tools: Vec<MCPTool>,
next_cursor: Option<String>,
},
/// Invoke the named tool with JSON `arguments`.
CallTool {
name: String,
arguments: Value,
},
/// Output of a tool call; `is_error` marks a failed call.
CallToolResult {
content: Vec<MCPContent>,
is_error: bool,
},
/// Request the prompt list; `cursor` is an opaque pagination token.
ListPrompts {
cursor: Option<String>,
},
/// A page of prompts plus an optional continuation token.
ListPromptsResult {
prompts: Vec<MCPPrompt>,
next_cursor: Option<String>,
},
/// Fetch a prompt by name, optionally supplying JSON arguments.
GetPrompt {
name: String,
arguments: Option<Value>,
},
/// The messages that make up a fetched prompt.
GetPromptResult {
description: Option<String>,
messages: Vec<MCPPromptMessage>,
},
/// Request the resource list; `cursor` is an opaque pagination token.
ListResources {
cursor: Option<String>,
},
/// A page of resources plus an optional continuation token.
ListResourcesResult {
resources: Vec<MCPResource>,
next_cursor: Option<String>,
},
/// Read the resource identified by `uri`.
ReadResource {
uri: String,
},
/// Contents returned for a `ReadResource` request.
ReadResourceResult {
contents: Vec<MCPResourceContent>,
},
/// Subscribe to changes of the resource at `uri` (presumably answered with
/// `ResourceUpdated` notifications — not implemented in the visible server code).
SubscribeResource {
uri: String,
},
/// Cancel a previous `SubscribeResource` for `uri`.
UnsubscribeResource {
uri: String,
},
/// Notification that the resource at `uri` changed.
ResourceUpdated {
uri: String,
},
/// Request log entries; `cursor` is an opaque pagination token.
ListLogs {
cursor: Option<String>,
},
/// A page of log entries plus an optional continuation token.
ListLogsResult {
logs: Vec<MCPLogEntry>,
next_cursor: Option<String>,
},
/// Change the log level to `level`.
SetLogLevel {
level: MCPLogLevel,
},
/// Error reply. This module's server uses `code` -1 for failed tool calls and
/// -2 for unsupported request types.
Error {
code: i32,
message: String,
data: Option<Value>,
},
}
/// Capability set exchanged during the MCP handshake.
///
/// Every field is optional; `None` means the feature is not advertised.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCapabilities {
/// Free-form experimental capabilities.
pub experimental: Option<Value>,
/// Sampling capability description, as raw JSON.
pub sampling: Option<Value>,
/// Tool-related capabilities.
pub tools: Option<MCPToolsCapability>,
/// Prompt-related capabilities.
pub prompts: Option<MCPPromptsCapability>,
/// Resource-related capabilities.
pub resources: Option<MCPResourcesCapability>,
/// Logging capabilities.
pub logging: Option<MCPLoggingCapability>,
}
/// Tool capability flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPToolsCapability {
/// Whether list-changed notifications for tools are advertised.
pub list_changed: Option<bool>,
}
/// Prompt capability flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPPromptsCapability {
/// Whether list-changed notifications for prompts are advertised.
pub list_changed: Option<bool>,
}
/// Resource capability flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPResourcesCapability {
/// Whether per-resource subscriptions are advertised.
pub subscribe: Option<bool>,
/// Whether list-changed notifications for resources are advertised.
pub list_changed: Option<bool>,
}
/// Logging capability description.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPLoggingCapability {
/// Log levels advertised as supported, if any.
pub levels: Option<Vec<MCPLogLevel>>,
}
/// Name and version identifying an MCP client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPClientInfo {
pub name: String,
pub version: String,
}
/// Name and version identifying an MCP server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPServerInfo {
pub name: String,
pub version: String,
}
/// Public, serialisable description of a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPTool {
/// Unique tool name. `MCPServer` rejects empty names and names longer than 100 bytes.
pub name: String,
/// Human-readable description.
pub description: String,
/// JSON schema for the arguments; `MCPServer` requires it to be a JSON object.
pub input_schema: Value,
}
/// A registered tool: its public definition, the handler that runs it, and
/// server-maintained runtime metadata. Construct one with [`Tool::new`].
pub struct Tool {
/// Definition advertised to clients.
pub definition: MCPTool,
/// Implementation invoked on each call.
pub handler: Box<dyn ToolHandler + Send + Sync>,
/// Call statistics and health, updated by the server.
pub metadata: ToolMetadata,
}
/// Runtime bookkeeping kept for each registered tool.
#[derive(Debug, Clone)]
pub struct ToolMetadata {
/// When the tool was built.
pub created_at: SystemTime,
/// Time of the most recent call, if any.
pub last_called: Option<SystemTime>,
/// Number of calls recorded so far.
pub call_count: u64,
/// Running mean of the recorded call durations.
pub avg_execution_time: Duration,
/// Current health, adjusted after each recorded call.
pub health_status: ToolHealthStatus,
/// Free-form labels supplied at build time.
pub tags: Vec<String>,
}
/// Health of a tool. A failed call moves `Healthy` to `Degraded` and `Degraded`
/// to `Unhealthy`; a successful call resets it to `Healthy` unless it is `Disabled`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ToolHealthStatus {
Healthy,
Degraded,
Unhealthy,
Disabled,
}
/// Executable behaviour behind a [`Tool`].
///
/// Only `execute` is required. `validate` and `get_requirements` have permissive
/// defaults. `MCPServer::call_tool` calls `validate` before `execute` and checks
/// `get_requirements` against the server configuration.
pub trait ToolHandler {
/// Runs the tool with the given JSON arguments and resolves to its JSON result.
fn execute(&self, arguments: Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value>> + Send + '_>>;
/// Checks `arguments` before execution; the default accepts anything.
fn validate(&self, arguments: &Value) -> Result<()> {
// Only silences the unused-parameter warning in this default body.
let _ = arguments;
Ok(())
}
/// Resource limits the tool declares; defaults to `ToolRequirements::default()`.
fn get_requirements(&self) -> ToolRequirements {
ToolRequirements::default()
}
}
/// Resource limits a tool declares through [`ToolHandler::get_requirements`].
#[derive(Debug, Clone)]
pub struct ToolRequirements {
/// Memory ceiling, compared against `MCPServerConfig::tool_memory_limit`
/// (presumably bytes — TODO confirm).
pub max_memory: Option<u64>,
/// Execution-time ceiling. It is compared against `MCPServerConfig::max_tool_execution_time`
/// and also caps the per-call timeout.
pub max_execution_time: Option<Duration>,
/// Named capabilities the tool needs (not checked by the visible server code).
pub required_capabilities: Vec<String>,
/// Whether the tool needs network access (not checked by the visible server code).
pub requires_network: bool,
/// Whether the tool needs filesystem access (not checked by the visible server code).
pub requires_filesystem: bool,
}
impl Default for ToolRequirements {
fn default() -> Self {
Self {
max_memory: Some(100 * 1024 * 1024), max_execution_time: Some(Duration::from_secs(30)),
required_capabilities: Vec::new(),
requires_network: false,
requires_filesystem: false,
}
}
}
/// One piece of content in a tool result or prompt message.
///
/// Internally tagged: serialised with a snake_case `"type"` key (`"text"`,
/// `"image"`, `"resource"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MCPContent {
/// Plain text.
Text {
text: String,
},
/// Image payload in `data` (presumably base64-encoded — TODO confirm) with its MIME type.
Image {
data: String,
mime_type: String,
},
/// Reference to a resource by URI.
Resource {
resource: MCPResourceReference,
},
}
/// A URI reference to a resource, with an optional type string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPResourceReference {
pub uri: String,
// NOTE(review): there is no `#[serde(rename = "type")]`, so this field is
// serialised under the literal key "type_"; confirm that peers expect that.
pub type_: Option<String>,
}
/// A prompt template advertised by a server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPPrompt {
pub name: String,
pub description: Option<String>,
/// Argument description as raw JSON.
pub arguments: Option<Value>,
}
/// A single message in a rendered prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPPromptMessage {
pub role: MCPRole,
pub content: MCPContent,
}
/// Author of a prompt message; serialised as `"user"`, `"assistant"` or `"system"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MCPRole {
User,
Assistant,
System,
}
/// A resource advertised by a server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPResource {
pub uri: String,
pub name: String,
pub description: Option<String>,
pub mime_type: Option<String>,
}
/// Contents of a read resource, carried either as `text` or as `blob`
/// (presumably base64 — TODO confirm; nothing here enforces that exactly one is set).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPResourceContent {
pub uri: String,
pub mime_type: String,
pub text: Option<String>,
pub blob: Option<String>,
}
/// Log severity, listed from least to most severe; serialised in snake_case
/// (e.g. `"warning"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum MCPLogLevel {
Debug,
Info,
Notice,
Warning,
Error,
Critical,
Alert,
Emergency,
}
/// A single log record: its severity, a JSON payload, and the optional logger name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPLogEntry {
pub level: MCPLogLevel,
pub data: Value,
pub logger: Option<String>,
}
/// An MCP service as published to or discovered from the DHT.
///
/// `MCPServer` caches these keyed by `service_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPService {
pub service_id: String,
/// Peer hosting the service.
pub node_id: PeerId,
/// Names of the tools the service exposes.
pub tools: Vec<String>,
pub capabilities: MCPCapabilities,
pub metadata: MCPServiceMetadata,
pub registered_at: SystemTime,
pub endpoint: MCPEndpoint,
}
/// Descriptive and operational metadata for an [`MCPService`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPServiceMetadata {
pub name: String,
pub version: String,
pub description: Option<String>,
pub tags: Vec<String>,
pub health_status: ServiceHealthStatus,
pub load_metrics: ServiceLoadMetrics,
}
/// Health of a service. There is no `rename_all`, so variants serialise under their
/// Rust names (e.g. `"Healthy"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ServiceHealthStatus {
Healthy,
Degraded,
Unhealthy,
Maintenance,
}
/// Load snapshot reported by a service. Units other than those in the field
/// names (e.g. for `error_rate`, `cpu_usage`, `memory_usage`) are not fixed
/// here — TODO confirm with the publisher.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceLoadMetrics {
pub active_requests: u32,
pub requests_per_second: f64,
pub avg_response_time_ms: f64,
pub error_rate: f64,
pub cpu_usage: f64,
pub memory_usage: u64,
}
/// Network endpoint at which a service can be reached.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPEndpoint {
pub protocol: String,
pub address: String,
pub port: Option<u16>,
/// Whether the endpoint uses TLS.
pub tls: bool,
/// Whether callers must authenticate.
pub auth_required: bool,
}
/// An in-process request record: the message plus routing, timing and auth data.
#[derive(Debug, Clone)]
pub struct MCPRequest {
pub request_id: String,
pub source_peer: PeerId,
pub target_peer: PeerId,
pub message: MCPMessage,
pub timestamp: SystemTime,
pub timeout: Duration,
pub auth_token: Option<String>,
}
/// Wire envelope for MCP traffic between peers, serialised as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct P2PMCPMessage {
pub message_type: P2PMCPMessageType,
/// Correlation id; replies reuse the id of the request they answer.
pub message_id: String,
/// Sender as claimed by the message itself.
pub source_peer: PeerId,
pub target_peer: Option<PeerId>,
/// Seconds since the Unix epoch at send time.
pub timestamp: u64,
pub payload: MCPMessage,
/// Hop budget; replies carry the request's value minus one (saturating at 0).
pub ttl: u8,
}
/// Kind of [`P2PMCPMessage`]; `MCPServer::handle_p2p_message` dispatches on it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum P2PMCPMessageType {
Request,
Response,
ServiceAdvertisement,
ServiceDiscovery,
}
/// A response delivered to the local waiter registered for `request_id`.
#[derive(Debug, Clone)]
pub struct MCPResponse {
pub request_id: String,
pub message: MCPMessage,
pub timestamp: SystemTime,
pub processing_time: Duration,
}
/// Per-call context passed to `MCPServer::call_tool`.
#[derive(Debug, Clone)]
pub struct MCPCallContext {
/// Peer used for rate limiting, permission checks and audit logging.
pub caller_id: PeerId,
pub timestamp: SystemTime,
/// Caller-requested timeout; the tool's own `max_execution_time` may shorten it.
pub timeout: Duration,
/// Bearer credentials, required for `SecurityLevel::Basic` tools when auth is enabled.
pub auth_info: Option<MCPAuthInfo>,
pub metadata: HashMap<String, String>,
}
/// Authentication material attached to a call.
#[derive(Debug, Clone)]
pub struct MCPAuthInfo {
/// Token string that `MCPServer::call_tool` verifies.
pub token: String,
pub token_type: String,
pub expires_at: Option<SystemTime>,
pub permissions: Vec<String>,
}
/// Configuration for [`MCPServer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPServerConfig {
pub server_name: String,
pub server_version: String,
pub enable_dht_discovery: bool,
pub max_concurrent_requests: usize,
pub request_timeout: Duration,
/// When true, `MCPServer::new` creates a security manager with a random secret.
pub enable_auth: bool,
pub enable_rate_limiting: bool,
/// Rate limit handed to the security manager (requests per minute, per the name).
pub rate_limit_rpm: u32,
pub enable_logging: bool,
/// Upper bound on a tool's declared `max_execution_time`.
pub max_tool_execution_time: Duration,
/// Upper bound on a tool's declared `max_memory`.
pub tool_memory_limit: u64,
}
impl Default for MCPServerConfig {
fn default() -> Self {
Self {
server_name: "P2P-MCP-Server".to_string(),
server_version: crate::VERSION.to_string(),
enable_dht_discovery: true,
max_concurrent_requests: 100,
request_timeout: DEFAULT_CALL_TIMEOUT,
enable_auth: true,
enable_rate_limiting: true,
rate_limit_rpm: 60,
enable_logging: true,
max_tool_execution_time: Duration::from_secs(30),
tool_memory_limit: 100 * 1024 * 1024, }
}
}
/// MCP server that hosts local tools, optionally publishes them to a DHT, and
/// answers MCP requests arriving over the P2P network.
pub struct MCPServer {
config: MCPServerConfig,
/// Registered tools keyed by name.
tools: Arc<RwLock<HashMap<String, Tool>>>,
// Allowed because prompt support is not implemented yet; the map is only initialised.
#[allow(dead_code)]
prompts: Arc<RwLock<HashMap<String, MCPPrompt>>>,
// Allowed because resource support is not implemented yet; the map is only initialised.
#[allow(dead_code)]
resources: Arc<RwLock<HashMap<String, MCPResource>>>,
/// Client sessions keyed by session id.
sessions: Arc<RwLock<HashMap<String, MCPSession>>>,
/// Waiters for remote responses, keyed by request/message id.
request_handlers: Arc<RwLock<HashMap<String, oneshot::Sender<MCPResponse>>>>,
/// Optional DHT used for tool registration and service discovery.
dht: Option<Arc<RwLock<DHT>>>,
/// Services offered by this node, keyed by service id.
local_services: Arc<RwLock<HashMap<String, MCPService>>>,
/// Cache of services discovered from the DHT, keyed by service id.
remote_services: Arc<RwLock<HashMap<String, MCPService>>>,
stats: Arc<RwLock<MCPServerStats>>,
/// Request queue sender; in the visible code its receiver is dropped in `new`.
request_tx: mpsc::UnboundedSender<MCPRequest>,
// Allowed because nothing reads this receiver yet (its sender is dropped in `new`).
#[allow(dead_code)]
response_rx: Arc<RwLock<mpsc::UnboundedReceiver<MCPResponse>>>,
/// Present only when `config.enable_auth` is set.
security_manager: Option<Arc<MCPSecurityManager>>,
audit_logger: Arc<SecurityAuditLogger>,
}
/// State tracked for one client session.
#[derive(Debug, Clone)]
pub struct MCPSession {
pub session_id: String,
pub peer_id: PeerId,
/// Capabilities the client advertised, if known yet.
pub client_capabilities: Option<MCPCapabilities>,
pub started_at: SystemTime,
pub last_activity: SystemTime,
pub state: MCPSessionState,
/// URIs of the resources this session subscribed to.
pub subscribed_resources: Vec<String>,
}
/// Lifecycle state of an [`MCPSession`]; `MCPServer::shutdown` marks sessions `Terminated`.
#[derive(Debug, Clone, PartialEq)]
pub enum MCPSessionState {
Initializing,
Active,
Inactive,
Terminated,
}
/// Aggregate server counters, returned as a snapshot by `MCPServer::get_stats`.
#[derive(Debug, Clone)]
pub struct MCPServerStats {
pub total_requests: u64,
pub total_responses: u64,
pub total_errors: u64,
/// Running mean duration of successful tool calls.
pub avg_response_time: Duration,
pub active_sessions: u32,
/// Number of tools registered through `register_tool`.
pub total_tools: u32,
/// Call count per tool name.
pub popular_tools: HashMap<String, u64>,
pub server_started_at: SystemTime,
}
impl Default for MCPServerStats {
fn default() -> Self {
Self {
total_requests: 0,
total_responses: 0,
total_errors: 0,
avg_response_time: Duration::from_millis(0),
active_sessions: 0,
total_tools: 0,
popular_tools: HashMap::new(),
server_started_at: SystemTime::now(),
}
}
}
impl MCPServer {
pub fn new(config: MCPServerConfig) -> Self {
let (request_tx, _request_rx) = mpsc::unbounded_channel();
let (_response_tx, response_rx) = mpsc::unbounded_channel();
let security_manager = if config.enable_auth {
let secret_key = (0..32).map(|_| rand::random::<u8>()).collect();
Some(Arc::new(MCPSecurityManager::new(secret_key, config.rate_limit_rpm)))
} else {
None
};
let server = Self {
config,
tools: Arc::new(RwLock::new(HashMap::new())),
prompts: Arc::new(RwLock::new(HashMap::new())),
resources: Arc::new(RwLock::new(HashMap::new())),
sessions: Arc::new(RwLock::new(HashMap::new())),
request_handlers: Arc::new(RwLock::new(HashMap::new())),
dht: None,
local_services: Arc::new(RwLock::new(HashMap::new())),
remote_services: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MCPServerStats::default())),
request_tx,
response_rx: Arc::new(RwLock::new(response_rx)),
security_manager,
audit_logger: Arc::new(SecurityAuditLogger::new(10000)), };
server
}
pub fn with_dht(mut self, dht: Arc<RwLock<DHT>>) -> Self {
self.dht = Some(dht);
self
}
pub async fn start(&self) -> Result<()> {
info!("Starting MCP server: {}", self.config.server_name);
self.start_request_processor().await?;
if self.dht.is_some() {
self.start_service_discovery().await?;
}
self.start_health_monitor().await?;
info!("MCP server started successfully");
Ok(())
}
pub async fn register_tool(&self, tool: Tool) -> Result<()> {
let tool_name = tool.definition.name.clone();
self.validate_tool(&tool).await?;
{
let mut tools = self.tools.write().await;
tools.insert(tool_name.clone(), tool);
}
{
let mut stats = self.stats.write().await;
stats.total_tools += 1;
}
if let Some(dht) = &self.dht {
self.register_tool_in_dht(&tool_name, dht).await?;
}
info!("Registered tool: {}", tool_name);
Ok(())
}
async fn validate_tool(&self, tool: &Tool) -> Result<()> {
let tools = self.tools.read().await;
if tools.contains_key(&tool.definition.name) {
return Err(P2PError::MCP(format!("Tool already exists: {}", tool.definition.name)).into());
}
if tool.definition.name.is_empty() || tool.definition.name.len() > 100 {
return Err(P2PError::MCP("Invalid tool name".to_string()).into());
}
if !tool.definition.input_schema.is_object() {
return Err(P2PError::MCP("Tool input schema must be an object".to_string()).into());
}
Ok(())
}
async fn register_tool_in_dht(&self, tool_name: &str, dht: &Arc<RwLock<DHT>>) -> Result<()> {
let key = Key::new(format!("mcp:tool:{}", tool_name).as_bytes());
let service_info = json!({
"tool_name": tool_name,
"node_id": "local_node", "registered_at": SystemTime::now().duration_since(std::time::UNIX_EPOCH).map_err(|e| P2PError::Network(format!("Time error: {}", e)))?.as_secs(),
"capabilities": self.get_server_capabilities().await
});
let dht_guard = dht.read().await;
dht_guard.put(key, serde_json::to_vec(&service_info)?).await?;
Ok(())
}
async fn get_server_capabilities(&self) -> MCPCapabilities {
MCPCapabilities {
experimental: None,
sampling: None,
tools: Some(MCPToolsCapability {
list_changed: Some(true),
}),
prompts: Some(MCPPromptsCapability {
list_changed: Some(true),
}),
resources: Some(MCPResourcesCapability {
subscribe: Some(true),
list_changed: Some(true),
}),
logging: Some(MCPLoggingCapability {
levels: Some(vec![
MCPLogLevel::Debug,
MCPLogLevel::Info,
MCPLogLevel::Warning,
MCPLogLevel::Error,
]),
}),
}
}
pub async fn call_tool(&self, tool_name: &str, arguments: Value, context: MCPCallContext) -> Result<Value> {
let start_time = Instant::now();
if !self.check_rate_limit(&context.caller_id).await? {
return Err(P2PError::MCP("Rate limit exceeded".to_string()));
}
if !self.check_permission(&context.caller_id, &MCPPermission::ExecuteTools).await? {
return Err(P2PError::MCP("Permission denied: execute tools".to_string()));
}
let tool_security_level = self.get_tool_security_policy(tool_name).await;
let is_trusted = self.is_trusted_peer(&context.caller_id).await;
match tool_security_level {
SecurityLevel::Admin => {
if !self.check_permission(&context.caller_id, &MCPPermission::Admin).await? {
return Err(P2PError::MCP("Permission denied: admin access required".to_string()));
}
}
SecurityLevel::Strong => {
if !is_trusted {
return Err(P2PError::MCP("Permission denied: trusted peer required".to_string()));
}
}
SecurityLevel::Basic => {
if self.config.enable_auth {
if let Some(auth_info) = &context.auth_info {
self.verify_auth_token(&auth_info.token).await?;
} else {
return Err(P2PError::MCP("Authentication required".to_string()));
}
}
}
SecurityLevel::Public => {
}
}
let mut details = HashMap::new();
details.insert("action".to_string(), "tool_call".to_string());
details.insert("tool_name".to_string(), tool_name.to_string());
details.insert("security_level".to_string(), format!("{:?}", tool_security_level));
self.audit_logger.log_event(
"tool_execution".to_string(),
context.caller_id.clone(),
details,
AuditSeverity::Info,
).await;
let tool_exists = {
let tools = self.tools.read().await;
tools.contains_key(tool_name)
};
if !tool_exists {
return Err(P2PError::MCP(format!("Tool not found: {}", tool_name)).into());
}
let requirements = {
let tools = self.tools.read().await;
let tool = tools.get(tool_name).unwrap();
if let Err(e) = tool.handler.validate(&arguments) {
return Err(P2PError::MCP(format!("Tool validation failed: {}", e)).into());
}
tool.handler.get_requirements()
};
self.check_resource_requirements(&requirements).await?;
let tools_clone = self.tools.clone();
let tool_name_owned = tool_name.to_string();
let execution_timeout = context.timeout.min(requirements.max_execution_time.unwrap_or(context.timeout));
let result = timeout(execution_timeout, async move {
let tools = tools_clone.read().await;
let tool = tools.get(&tool_name_owned).unwrap(); tool.handler.execute(arguments).await
}).await
.map_err(|_| P2PError::MCP("Tool execution timeout".to_string()))?
.map_err(|e| P2PError::MCP(format!("Tool execution failed: {}", e)))?;
let execution_time = start_time.elapsed();
self.update_tool_stats(tool_name, execution_time, true).await;
{
let mut stats = self.stats.write().await;
stats.total_requests += 1;
stats.total_responses += 1;
let new_total_time = stats.avg_response_time.mul_f64(stats.total_responses as f64 - 1.0) + execution_time;
stats.avg_response_time = new_total_time.div_f64(stats.total_responses as f64);
*stats.popular_tools.entry(tool_name.to_string()).or_insert(0) += 1;
}
debug!("Tool '{}' executed in {:?}", tool_name, execution_time);
Ok(result)
}
async fn check_resource_requirements(&self, requirements: &ToolRequirements) -> Result<()> {
if let Some(max_memory) = requirements.max_memory {
if max_memory > self.config.tool_memory_limit {
return Err(P2PError::MCP("Tool memory requirement exceeds limit".to_string()).into());
}
}
if let Some(max_execution_time) = requirements.max_execution_time {
if max_execution_time > self.config.max_tool_execution_time {
return Err(P2PError::MCP("Tool execution time requirement exceeds limit".to_string()).into());
}
}
Ok(())
}
async fn update_tool_stats(&self, tool_name: &str, execution_time: Duration, success: bool) {
let mut tools = self.tools.write().await;
if let Some(tool) = tools.get_mut(tool_name) {
tool.metadata.call_count += 1;
tool.metadata.last_called = Some(SystemTime::now());
let new_total_time = tool.metadata.avg_execution_time.mul_f64(tool.metadata.call_count as f64 - 1.0) + execution_time;
tool.metadata.avg_execution_time = new_total_time.div_f64(tool.metadata.call_count as f64);
if !success {
tool.metadata.health_status = match tool.metadata.health_status {
ToolHealthStatus::Healthy => ToolHealthStatus::Degraded,
ToolHealthStatus::Degraded => ToolHealthStatus::Unhealthy,
other => other,
};
} else if tool.metadata.health_status != ToolHealthStatus::Disabled {
tool.metadata.health_status = ToolHealthStatus::Healthy;
}
}
}
pub async fn list_tools(&self, _cursor: Option<String>) -> Result<(Vec<MCPTool>, Option<String>)> {
let tools = self.tools.read().await;
let tool_definitions: Vec<MCPTool> = tools.values()
.map(|tool| tool.definition.clone())
.collect();
Ok((tool_definitions, None))
}
async fn start_request_processor(&self) -> Result<()> {
let _request_tx = self.request_tx.clone();
let _server_clone = Arc::new(self);
tokio::spawn(async move {
info!("MCP request processor started");
loop {
tokio::time::sleep(Duration::from_millis(100)).await;
break;
}
info!("MCP request processor stopped");
});
Ok(())
}
async fn start_service_discovery(&self) -> Result<()> {
if let Some(dht) = self.dht.clone() {
let _stats = self.stats.clone();
let remote_services = self.remote_services.clone();
tokio::spawn(async move {
info!("MCP service discovery started");
loop {
tokio::time::sleep(SERVICE_DISCOVERY_INTERVAL).await;
let key = Key::new(b"mcp:services");
let dht_guard = dht.read().await;
match dht_guard.get(&key).await {
Some(record) => {
match serde_json::from_slice::<Vec<MCPService>>(&record.value) {
Ok(services) => {
debug!("Discovered {} MCP services", services.len());
{
let mut remote_cache = remote_services.write().await;
for service in services {
remote_cache.insert(service.service_id.clone(), service);
}
}
}
Err(e) => {
debug!("Failed to deserialize services: {}", e);
}
}
}
None => {
debug!("No MCP services found in DHT");
}
}
}
});
}
Ok(())
}
async fn start_health_monitor(&self) -> Result<()> {
Ok(())
}
pub async fn get_stats(&self) -> MCPServerStats {
self.stats.read().await.clone()
}
pub async fn discover_remote_services(&self) -> Result<Vec<MCPService>> {
if let Some(dht) = &self.dht {
let key = Key::new(b"mcp:services");
let dht_guard = dht.read().await;
match dht_guard.get(&key).await {
Some(record) => {
match serde_json::from_slice::<Vec<MCPService>>(&record.value) {
Ok(services) => {
{
let mut remote_services = self.remote_services.write().await;
for service in &services {
remote_services.insert(service.service_id.clone(), service.clone());
}
}
Ok(services)
}
Err(e) => {
debug!("Failed to deserialize services: {}", e);
Ok(Vec::new())
}
}
}
None => Ok(Vec::new()),
}
} else {
Ok(Vec::new())
}
}
pub async fn call_remote_tool(&self, peer_id: &PeerId, tool_name: &str, arguments: Value, context: MCPCallContext) -> Result<Value> {
let request_id = uuid::Uuid::new_v4().to_string();
let mcp_message = MCPMessage::CallTool {
name: tool_name.to_string(),
arguments,
};
let p2p_message = P2PMCPMessage {
message_type: P2PMCPMessageType::Request,
message_id: request_id.clone(),
source_peer: context.caller_id.clone(),
target_peer: Some(peer_id.clone()),
timestamp: SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| P2PError::Network(format!("Time error: {}", e)))?
.as_secs(),
payload: mcp_message,
ttl: 5, };
let message_data = serde_json::to_vec(&p2p_message)
.map_err(|e| P2PError::Serialization(e))?;
if message_data.len() > MAX_MESSAGE_SIZE {
return Err(P2PError::MCP("Message too large".to_string()));
}
let (response_tx, _response_rx) = oneshot::channel::<MCPResponse>();
{
let mut handlers = self.request_handlers.write().await;
handlers.insert(request_id.clone(), response_tx);
}
Err(P2PError::MCP("Remote tool calling requires P2P network integration".to_string()))
}
pub async fn handle_p2p_message(&self, message_data: &[u8], source_peer: &PeerId) -> Result<Option<Vec<u8>>> {
let p2p_message: P2PMCPMessage = serde_json::from_slice(message_data)
.map_err(|e| P2PError::Serialization(e))?;
debug!("Received MCP message from {}: {:?}", source_peer, p2p_message.message_type);
match p2p_message.message_type {
P2PMCPMessageType::Request => {
self.handle_remote_request(p2p_message).await
}
P2PMCPMessageType::Response => {
self.handle_remote_response(p2p_message).await?;
Ok(None) }
P2PMCPMessageType::ServiceAdvertisement => {
self.handle_service_advertisement(p2p_message).await?;
Ok(None)
}
P2PMCPMessageType::ServiceDiscovery => {
self.handle_service_discovery(p2p_message).await
}
}
}
async fn handle_remote_request(&self, message: P2PMCPMessage) -> Result<Option<Vec<u8>>> {
match message.payload {
MCPMessage::CallTool { name, arguments } => {
let context = MCPCallContext {
caller_id: message.source_peer.clone(),
timestamp: SystemTime::now(),
timeout: DEFAULT_CALL_TIMEOUT,
auth_info: None,
metadata: HashMap::new(),
};
let result = self.call_tool(&name, arguments, context).await;
let response_payload = match result {
Ok(value) => MCPMessage::CallToolResult {
content: vec![MCPContent::Text { text: value.to_string() }],
is_error: false,
},
Err(e) => MCPMessage::Error {
code: -1,
message: e.to_string(),
data: None,
},
};
let response_message = P2PMCPMessage {
message_type: P2PMCPMessageType::Response,
message_id: message.message_id,
source_peer: "local".to_string(), target_peer: Some(message.source_peer),
timestamp: SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| P2PError::Network(format!("Time error: {}", e)))?
.as_secs(),
payload: response_payload,
ttl: message.ttl.saturating_sub(1),
};
let response_data = serde_json::to_vec(&response_message)
.map_err(|e| P2PError::Serialization(e))?;
Ok(Some(response_data))
}
MCPMessage::ListTools { cursor: _ } => {
let (tools, _) = self.list_tools(None).await?;
let response_payload = MCPMessage::ListToolsResult {
tools,
next_cursor: None,
};
let response_message = P2PMCPMessage {
message_type: P2PMCPMessageType::Response,
message_id: message.message_id,
source_peer: "local".to_string(), target_peer: Some(message.source_peer),
timestamp: SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| P2PError::Network(format!("Time error: {}", e)))?
.as_secs(),
payload: response_payload,
ttl: message.ttl.saturating_sub(1),
};
let response_data = serde_json::to_vec(&response_message)
.map_err(|e| P2PError::Serialization(e))?;
Ok(Some(response_data))
}
_ => {
let error_response = P2PMCPMessage {
message_type: P2PMCPMessageType::Response,
message_id: message.message_id,
source_peer: "local".to_string(), target_peer: Some(message.source_peer),
timestamp: SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| P2PError::Network(format!("Time error: {}", e)))?
.as_secs(),
payload: MCPMessage::Error {
code: -2,
message: "Unsupported request type".to_string(),
data: None,
},
ttl: message.ttl.saturating_sub(1),
};
let response_data = serde_json::to_vec(&error_response)
.map_err(|e| P2PError::Serialization(e))?;
Ok(Some(response_data))
}
}
}
pub async fn generate_auth_token(&self, peer_id: &PeerId, permissions: Vec<MCPPermission>, ttl: Duration) -> Result<String> {
if let Some(security_manager) = &self.security_manager {
let token = security_manager.generate_token(peer_id, permissions, ttl).await?;
let mut details = HashMap::new();
details.insert("action".to_string(), "token_generated".to_string());
details.insert("ttl_seconds".to_string(), ttl.as_secs().to_string());
self.audit_logger.log_event(
"authentication".to_string(),
peer_id.clone(),
details,
AuditSeverity::Info,
).await;
Ok(token)
} else {
Err(P2PError::MCP("Authentication not enabled".to_string()))
}
}
pub async fn verify_auth_token(&self, token: &str) -> Result<TokenPayload> {
if let Some(security_manager) = &self.security_manager {
match security_manager.verify_token(token).await {
Ok(payload) => {
let mut details = HashMap::new();
details.insert("action".to_string(), "token_verified".to_string());
details.insert("subject".to_string(), payload.sub.clone());
self.audit_logger.log_event(
"authentication".to_string(),
payload.iss.clone(),
details,
AuditSeverity::Info,
).await;
Ok(payload)
}
Err(e) => {
let mut details = HashMap::new();
details.insert("action".to_string(), "token_verification_failed".to_string());
details.insert("error".to_string(), e.to_string());
self.audit_logger.log_event(
"authentication".to_string(),
"unknown".to_string(),
details,
AuditSeverity::Warning,
).await;
Err(e)
}
}
} else {
Err(P2PError::MCP("Authentication not enabled".to_string()))
}
}
pub async fn check_permission(&self, peer_id: &PeerId, permission: &MCPPermission) -> Result<bool> {
if let Some(security_manager) = &self.security_manager {
security_manager.check_permission(peer_id, permission).await
} else {
Ok(true)
}
}
pub async fn check_rate_limit(&self, peer_id: &PeerId) -> Result<bool> {
if let Some(security_manager) = &self.security_manager {
let allowed = security_manager.check_rate_limit(peer_id).await?;
if !allowed {
let mut details = HashMap::new();
details.insert("action".to_string(), "rate_limit_exceeded".to_string());
self.audit_logger.log_event(
"rate_limiting".to_string(),
peer_id.clone(),
details,
AuditSeverity::Warning,
).await;
}
Ok(allowed)
} else {
Ok(true)
}
}
pub async fn grant_permission(&self, peer_id: &PeerId, permission: MCPPermission) -> Result<()> {
if let Some(security_manager) = &self.security_manager {
security_manager.grant_permission(peer_id, permission.clone()).await?;
let mut details = HashMap::new();
details.insert("action".to_string(), "permission_granted".to_string());
details.insert("permission".to_string(), permission.as_str().to_string());
self.audit_logger.log_event(
"authorization".to_string(),
peer_id.clone(),
details,
AuditSeverity::Info,
).await;
Ok(())
} else {
Err(P2PError::MCP("Security not enabled".to_string()))
}
}
pub async fn revoke_permission(&self, peer_id: &PeerId, permission: &MCPPermission) -> Result<()> {
if let Some(security_manager) = &self.security_manager {
security_manager.revoke_permission(peer_id, permission).await?;
let mut details = HashMap::new();
details.insert("action".to_string(), "permission_revoked".to_string());
details.insert("permission".to_string(), permission.as_str().to_string());
self.audit_logger.log_event(
"authorization".to_string(),
peer_id.clone(),
details,
AuditSeverity::Info,
).await;
Ok(())
} else {
Err(P2PError::MCP("Security not enabled".to_string()))
}
}
pub async fn add_trusted_peer(&self, peer_id: PeerId) -> Result<()> {
if let Some(security_manager) = &self.security_manager {
security_manager.add_trusted_peer(peer_id.clone()).await?;
let mut details = HashMap::new();
details.insert("action".to_string(), "trusted_peer_added".to_string());
self.audit_logger.log_event(
"trust_management".to_string(),
peer_id,
details,
AuditSeverity::Info,
).await;
Ok(())
} else {
Err(P2PError::MCP("Security not enabled".to_string()))
}
}
pub async fn is_trusted_peer(&self, peer_id: &PeerId) -> bool {
if let Some(security_manager) = &self.security_manager {
security_manager.is_trusted_peer(peer_id).await
} else {
false
}
}
pub async fn set_tool_security_policy(&self, tool_name: String, level: SecurityLevel) -> Result<()> {
if let Some(security_manager) = &self.security_manager {
security_manager.set_tool_policy(tool_name.clone(), level.clone()).await?;
let mut details = HashMap::new();
details.insert("action".to_string(), "tool_policy_set".to_string());
details.insert("tool_name".to_string(), tool_name);
details.insert("security_level".to_string(), format!("{:?}", level));
self.audit_logger.log_event(
"security_policy".to_string(),
"system".to_string(),
details,
AuditSeverity::Info,
).await;
Ok(())
} else {
Err(P2PError::MCP("Security not enabled".to_string()))
}
}
pub async fn get_tool_security_policy(&self, tool_name: &str) -> SecurityLevel {
if let Some(security_manager) = &self.security_manager {
security_manager.get_tool_policy(tool_name).await
} else {
SecurityLevel::Public
}
}
pub async fn get_peer_security_stats(&self, peer_id: &PeerId) -> Option<PeerACL> {
if let Some(security_manager) = &self.security_manager {
security_manager.get_peer_stats(peer_id).await
} else {
None
}
}
pub async fn get_security_audit(&self, limit: Option<usize>) -> Vec<SecurityAuditEntry> {
self.audit_logger.get_recent_entries(limit).await
}
pub async fn security_cleanup(&self) -> Result<()> {
if let Some(security_manager) = &self.security_manager {
security_manager.cleanup().await?;
}
Ok(())
}
async fn handle_remote_response(&self, message: P2PMCPMessage) -> Result<()> {
let response_tx = {
let mut handlers = self.request_handlers.write().await;
handlers.remove(&message.message_id)
};
if let Some(tx) = response_tx {
let response = MCPResponse {
request_id: message.message_id,
message: message.payload,
timestamp: SystemTime::now(),
processing_time: Duration::from_millis(0), };
let _ = tx.send(response);
} else {
debug!("Received response for unknown request: {}", message.message_id);
}
Ok(())
}
async fn handle_service_advertisement(&self, _message: P2PMCPMessage) -> Result<()> {
Ok(())
}
/// Answers a service-discovery request with an advertisement of the local tools.
///
/// When at least one local service is registered, this builds a
/// `ServiceAdvertisement` addressed back to the requesting peer and returns it
/// serialized as JSON bytes. The payload is a `ListToolsResult` listing every
/// tool name from every local service. Each tool gets the placeholder
/// description `"Remote tool"` and a generic `{"type": "object"}` schema.
/// The reply's TTL is the request's TTL minus one, saturating at zero.
/// Returns `Ok(None)` when no local services are registered.
///
/// # Errors
/// Fails if the system clock reads earlier than the Unix epoch, or if the
/// advertisement cannot be serialized.
async fn handle_service_discovery(&self, message: P2PMCPMessage) -> Result<Option<Vec<u8>>> {
    // Snapshot the registry so the read lock is released before building the reply.
    let snapshot: Vec<MCPService> = self
        .local_services
        .read()
        .await
        .values()
        .cloned()
        .collect();

    if snapshot.is_empty() {
        return Ok(None);
    }

    // Services only carry tool names, so each one is advertised with a generic definition.
    let mut advertised_tools = Vec::new();
    for service in snapshot {
        for tool_name in service.tools {
            advertised_tools.push(MCPTool {
                name: tool_name,
                description: "Remote tool".to_string(),
                input_schema: json!({"type": "object"}),
            });
        }
    }

    let sent_at = SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| P2PError::Network(format!("Time error: {}", e)))?
        .as_secs();

    let advertisement = P2PMCPMessage {
        message_type: P2PMCPMessageType::ServiceAdvertisement,
        message_id: uuid::Uuid::new_v4().to_string(),
        // NOTE(review): the sender is the literal "local", not this node's peer id.
        source_peer: "local".to_string(),
        target_peer: Some(message.source_peer),
        timestamp: sent_at,
        payload: MCPMessage::ListToolsResult {
            tools: advertised_tools,
            next_cursor: None,
        },
        ttl: message.ttl.saturating_sub(1),
    };

    let encoded = serde_json::to_vec(&advertisement).map_err(P2PError::Serialization)?;
    Ok(Some(encoded))
}
pub async fn shutdown(&self) -> Result<()> {
info!("Shutting down MCP server");
{
let mut sessions = self.sessions.write().await;
for session in sessions.values_mut() {
session.state = MCPSessionState::Terminated;
}
sessions.clear();
}
info!("MCP server shutdown complete");
Ok(())
}
}
impl Tool {
    /// Starts building a [`Tool`] from a name, a description and a JSON input schema.
    ///
    /// Returns a [`ToolBuilder`] with no handler and no tags. A handler must be
    /// supplied via [`ToolBuilder::handler`] before [`ToolBuilder::build`] succeeds.
    pub fn new(name: &str, description: &str, input_schema: Value) -> ToolBuilder {
        let (name, description) = (String::from(name), String::from(description));
        ToolBuilder {
            name,
            description,
            input_schema,
            handler: None,
            tags: vec![],
        }
    }
}
/// Builder for [`Tool`], obtained from [`Tool::new`].
///
/// Holds the fields of the tool's `MCPTool` definition plus an optional handler
/// and tags. [`ToolBuilder::build`] fails if no handler has been set.
pub struct ToolBuilder {
    /// Becomes the `name` of the resulting `MCPTool` definition.
    name: String,
    /// Becomes the `description` of the resulting `MCPTool` definition.
    description: String,
    /// JSON value used as the tool definition's `input_schema`.
    input_schema: Value,
    /// Executes the tool; `None` until `handler` is called, and required by `build`.
    handler: Option<Box<dyn ToolHandler + Send + Sync>>,
    /// Copied into the built tool's `ToolMetadata::tags`.
    tags: Vec<String>,
}
impl ToolBuilder {
pub fn handler<H: ToolHandler + Send + Sync + 'static>(mut self, handler: H) -> Self {
self.handler = Some(Box::new(handler));
self
}
pub fn tags(mut self, tags: Vec<String>) -> Self {
self.tags = tags;
self
}
pub fn build(self) -> Result<Tool> {
let handler = self.handler
.ok_or_else(|| P2PError::MCP("Tool handler is required".to_string()))?;
let definition = MCPTool {
name: self.name,
description: self.description,
input_schema: self.input_schema,
};
let metadata = ToolMetadata {
created_at: SystemTime::now(),
last_called: None,
call_count: 0,
avg_execution_time: Duration::from_millis(0),
health_status: ToolHealthStatus::Healthy,
tags: self.tags,
};
Ok(Tool {
definition,
handler,
metadata,
})
}
}
/// Wraps a callable `F: Fn(Value) -> Fut` (where `Fut` resolves to `Result<Value>`)
/// so it can be used as a [`ToolHandler`].
pub struct FunctionToolHandler<F> {
    // The wrapped callable; invoked once per `execute` call with the tool arguments.
    function: F,
}
// Only `execute` is implemented here; the remaining `ToolHandler` methods use
// the trait's default implementations.
impl<F, Fut> ToolHandler for FunctionToolHandler<F>
where
    F: Fn(Value) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<Value>> + Send + 'static,
{
    /// Invokes the wrapped function with `arguments` and returns its future, boxed and pinned.
    fn execute(
        &self,
        arguments: Value,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value>> + Send + '_>> {
        let pending = (self.function)(arguments);
        Box::pin(pending)
    }
}
impl<F> FunctionToolHandler<F> {
pub fn new(function: F) -> Self {
Self { function }
}
}
impl MCPService {
    /// Creates a service record for `service_id` hosted on `node_id`.
    ///
    /// The new service has:
    /// - no tools, and only the tools capability (with `list_changed`);
    /// - metadata named "MCP Service", version "1.0.0", healthy, with zeroed load metrics;
    /// - `registered_at` set to the current time;
    /// - a `p2p` endpoint with an empty address, no port, no TLS and no auth requirement.
    pub fn new(service_id: String, node_id: PeerId) -> Self {
        let capabilities = MCPCapabilities {
            experimental: None,
            sampling: None,
            tools: Some(MCPToolsCapability { list_changed: Some(true) }),
            prompts: None,
            resources: None,
            logging: None,
        };

        let load_metrics = ServiceLoadMetrics {
            active_requests: 0,
            requests_per_second: 0.0,
            avg_response_time_ms: 0.0,
            error_rate: 0.0,
            cpu_usage: 0.0,
            memory_usage: 0,
        };

        let metadata = MCPServiceMetadata {
            name: String::from("MCP Service"),
            version: String::from("1.0.0"),
            description: None,
            tags: vec![],
            health_status: ServiceHealthStatus::Healthy,
            load_metrics,
        };

        let endpoint = MCPEndpoint {
            protocol: String::from("p2p"),
            address: String::new(),
            port: None,
            tls: false,
            auth_required: false,
        };

        Self {
            service_id,
            node_id,
            tools: vec![],
            capabilities,
            metadata,
            registered_at: SystemTime::now(),
            endpoint,
        }
    }
}
/// Default capabilities enable:
/// - tools and prompts, with list-change notifications;
/// - resources, with subscriptions and list-change notifications;
/// - logging at the Debug, Info, Warning and Error levels.
///
/// Notice, Critical, Alert and Emergency are not listed. `experimental` and
/// `sampling` are left unset.
impl Default for MCPCapabilities {
    fn default() -> Self {
        let supported_levels = vec![
            MCPLogLevel::Debug,
            MCPLogLevel::Info,
            MCPLogLevel::Warning,
            MCPLogLevel::Error,
        ];
        Self {
            experimental: None,
            sampling: None,
            tools: Some(MCPToolsCapability { list_changed: Some(true) }),
            prompts: Some(MCPPromptsCapability { list_changed: Some(true) }),
            resources: Some(MCPResourcesCapability {
                subscribe: Some(true),
                list_changed: Some(true),
            }),
            logging: Some(MCPLoggingCapability { levels: Some(supported_levels) }),
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for MCP server tool registration/calls, tool handler helpers,
    // and JSON round-trips of the protocol types defined in this module.
    use super::*;
    use crate::dht::{DHT, DHTConfig, Key};
    use std::pin::Pin;
    use std::future::Future;
    use tokio::time::timeout;

    /// Configurable `ToolHandler` for tests: waits `execution_time`, then either
    /// fails (when `should_error`) or echoes its name and arguments as JSON.
    struct TestTool {
        name: String,
        // When true, `execute` returns an MCP error after the simulated delay.
        should_error: bool,
        // Simulated work duration slept inside `execute`.
        execution_time: Duration,
    }

    impl TestTool {
        // Succeeding tool with a 10 ms simulated execution time.
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                should_error: false,
                execution_time: Duration::from_millis(10),
            }
        }

        // Makes `execute` fail instead of succeed.
        fn with_error(mut self) -> Self {
            self.should_error = true;
            self
        }

        // Overrides the simulated execution time.
        fn with_execution_time(mut self, duration: Duration) -> Self {
            self.execution_time = duration;
            self
        }
    }

    impl ToolHandler for TestTool {
        fn execute(&self, arguments: Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + '_>> {
            // Copy the configuration into the future so it owns everything it uses.
            let should_error = self.should_error;
            let execution_time = self.execution_time;
            let name = self.name.clone();
            Box::pin(async move {
                tokio::time::sleep(execution_time).await;
                if should_error {
                    return Err(P2PError::MCP(format!("Test error from tool {}", name)).into());
                }
                Ok(json!({
                    "tool": name,
                    "arguments": arguments,
                    "result": "success"
                }))
            })
        }

        // Accepts only JSON objects as arguments.
        fn validate(&self, arguments: &Value) -> Result<()> {
            if !arguments.is_object() {
                return Err(P2PError::MCP("Arguments must be an object".to_string()).into());
            }
            Ok(())
        }

        // Fixed requirements; the exact values are asserted in `test_tool_requirements`.
        fn get_requirements(&self) -> ToolRequirements {
            ToolRequirements {
                max_memory: Some(1024 * 1024),
                max_execution_time: Some(Duration::from_secs(5)),
                required_capabilities: vec!["test".to_string()],
                requires_network: false,
                requires_filesystem: false,
            }
        }
    }

    // Server config with auth and rate limiting disabled, at most 10 concurrent
    // requests, and 30 s request/tool-execution limits.
    async fn create_test_mcp_server() -> MCPServer {
        let config = MCPServerConfig {
            server_name: "test_server".to_string(),
            server_version: "1.0.0".to_string(),
            enable_auth: false,
            enable_rate_limiting: false,
            max_concurrent_requests: 10,
            request_timeout: Duration::from_secs(30),
            enable_dht_discovery: true,
            rate_limit_rpm: 60,
            enable_logging: true,
            max_tool_execution_time: Duration::from_secs(30),
            tool_memory_limit: 100 * 1024 * 1024,
        };
        MCPServer::new(config)
    }

    // Healthy, never-called tool named `name`, backed by a succeeding `TestTool`,
    // with an object schema declaring a single string `input` property.
    fn create_test_tool(name: &str) -> Tool {
        Tool {
            definition: MCPTool {
                name: name.to_string(),
                description: format!("Test tool: {}", name),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "input": { "type": "string" }
                    }
                }),
            },
            handler: Box::new(TestTool::new(name)),
            metadata: ToolMetadata {
                created_at: SystemTime::now(),
                last_called: None,
                call_count: 0,
                avg_execution_time: Duration::from_millis(0),
                health_status: ToolHealthStatus::Healthy,
                tags: vec!["test".to_string()],
            },
        }
    }

    // NOTE(review): not referenced by any test in this module; may trigger a
    // dead_code warning.
    async fn create_test_dht() -> DHT {
        let local_id = Key::new(b"test_node_id");
        let config = DHTConfig::default();
        DHT::new(local_id, config)
    }

    // Call context for `caller_id` with a 30 s timeout, no auth info and no metadata.
    fn create_test_context(caller_id: PeerId) -> MCPCallContext {
        MCPCallContext {
            caller_id,
            timestamp: SystemTime::now(),
            timeout: Duration::from_secs(30),
            auth_info: None,
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_mcp_server_creation() {
        let server = create_test_mcp_server().await;
        assert_eq!(server.config.server_name, "test_server");
        assert_eq!(server.config.server_version, "1.0.0");
        assert!(!server.config.enable_auth);
        assert!(!server.config.enable_rate_limiting);
    }

    // Registering a tool stores it by name and bumps the server's tool count.
    #[tokio::test]
    async fn test_tool_registration() -> Result<()> {
        let server = create_test_mcp_server().await;
        let tool = create_test_tool("test_calculator");
        server.register_tool(tool).await?;
        let tools = server.tools.read().await;
        assert!(tools.contains_key("test_calculator"));
        assert_eq!(tools.get("test_calculator").unwrap().definition.name, "test_calculator");
        let stats = server.stats.read().await;
        assert_eq!(stats.total_tools, 1);
        Ok(())
    }

    // A second registration under the same name is rejected.
    #[tokio::test]
    async fn test_tool_registration_duplicate() -> Result<()> {
        let server = create_test_mcp_server().await;
        let tool1 = create_test_tool("duplicate_tool");
        let tool2 = create_test_tool("duplicate_tool");
        server.register_tool(tool1).await?;
        let result = server.register_tool(tool2).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Tool already exists"));
        Ok(())
    }

    #[tokio::test]
    async fn test_tool_validation() {
        let server = create_test_mcp_server().await;
        // Empty name is rejected.
        let mut invalid_tool = create_test_tool("");
        let result = server.validate_tool(&invalid_tool).await;
        assert!(result.is_err());
        // A 200-character name is rejected.
        invalid_tool.definition.name = "a".repeat(200);
        let result = server.validate_tool(&invalid_tool).await;
        assert!(result.is_err());
        // A non-object input schema is rejected.
        let mut invalid_schema_tool = create_test_tool("valid_name");
        invalid_schema_tool.definition.input_schema = json!("not an object");
        let result = server.validate_tool(&invalid_schema_tool).await;
        assert!(result.is_err());
        // A well-formed tool passes.
        let valid_tool = create_test_tool("valid_tool");
        let result = server.validate_tool(&valid_tool).await;
        assert!(result.is_ok());
    }

    // A successful call returns the handler's output and updates call metadata.
    #[tokio::test]
    async fn test_tool_call_success() -> Result<()> {
        let server = create_test_mcp_server().await;
        let tool = create_test_tool("success_tool");
        server.register_tool(tool).await?;
        let caller_id = "test_peer_123".to_string();
        let context = create_test_context(caller_id);
        let arguments = json!({"input": "test data"});
        let result = server.call_tool("success_tool", arguments.clone(), context).await?;
        assert_eq!(result["tool"], "success_tool");
        assert_eq!(result["arguments"], arguments);
        assert_eq!(result["result"], "success");
        let tools = server.tools.read().await;
        let tool_metadata = &tools.get("success_tool").unwrap().metadata;
        assert_eq!(tool_metadata.call_count, 1);
        assert!(tool_metadata.last_called.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_tool_call_nonexistent() -> Result<()> {
        let server = create_test_mcp_server().await;
        let caller_id = "test_peer_456".to_string();
        let context = create_test_context(caller_id);
        let arguments = json!({"input": "test"});
        let result = server.call_tool("nonexistent_tool", arguments, context).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Tool not found"));
        Ok(())
    }

    // Errors raised by the handler are surfaced to the caller with their message.
    #[tokio::test]
    async fn test_tool_call_handler_error() -> Result<()> {
        let server = create_test_mcp_server().await;
        let tool = Tool {
            definition: MCPTool {
                name: "error_tool".to_string(),
                description: "Tool that always errors".to_string(),
                input_schema: json!({"type": "object"}),
            },
            handler: Box::new(TestTool::new("error_tool").with_error()),
            metadata: ToolMetadata {
                created_at: SystemTime::now(),
                last_called: None,
                call_count: 0,
                avg_execution_time: Duration::from_millis(0),
                health_status: ToolHealthStatus::Healthy,
                tags: vec![],
            },
        };
        server.register_tool(tool).await?;
        let caller_id = "test_peer_error".to_string();
        let context = create_test_context(caller_id);
        let arguments = json!({"input": "test"});
        let result = server.call_tool("error_tool", arguments, context).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Test error from tool error_tool"));
        Ok(())
    }

    // NOTE(review): the 100 ms deadline comes from this test's own
    // `tokio::time::timeout`, not from the server's `request_timeout` or
    // `max_tool_execution_time`. It only checks that the 2 s tool does not
    // finish within 100 ms.
    #[tokio::test]
    async fn test_tool_call_timeout() -> Result<()> {
        let server = create_test_mcp_server().await;
        let slow_tool = Tool {
            definition: MCPTool {
                name: "slow_tool".to_string(),
                description: "Tool that takes too long".to_string(),
                input_schema: json!({"type": "object"}),
            },
            handler: Box::new(TestTool::new("slow_tool").with_execution_time(Duration::from_secs(2))),
            metadata: ToolMetadata {
                created_at: SystemTime::now(),
                last_called: None,
                call_count: 0,
                avg_execution_time: Duration::from_millis(0),
                health_status: ToolHealthStatus::Healthy,
                tags: vec![],
            },
        };
        server.register_tool(slow_tool).await?;
        let caller_id = "test_peer_error".to_string();
        let context = create_test_context(caller_id);
        let arguments = json!({"input": "test"});
        let result = timeout(
            Duration::from_millis(100),
            server.call_tool("slow_tool", arguments, context)
        ).await;
        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_tool_requirements() {
        let tool = TestTool::new("req_tool");
        let requirements = tool.get_requirements();
        assert_eq!(requirements.max_memory, Some(1024 * 1024));
        assert_eq!(requirements.max_execution_time, Some(Duration::from_secs(5)));
        assert_eq!(requirements.required_capabilities, vec!["test"]);
        assert!(!requirements.requires_network);
        assert!(!requirements.requires_filesystem);
    }

    // `TestTool::validate` accepts objects and rejects strings and numbers.
    #[tokio::test]
    async fn test_tool_validation_handler() {
        let tool = TestTool::new("validation_tool");
        let valid_args = json!({"key": "value"});
        assert!(tool.validate(&valid_args).is_ok());
        let invalid_args = json!("not an object");
        assert!(tool.validate(&invalid_args).is_err());
        let invalid_args = json!(123);
        assert!(tool.validate(&invalid_args).is_err());
    }

    #[tokio::test]
    async fn test_tool_health_status() {
        let mut metadata = ToolMetadata {
            created_at: SystemTime::now(),
            last_called: None,
            call_count: 0,
            avg_execution_time: Duration::from_millis(0),
            health_status: ToolHealthStatus::Healthy,
            tags: vec![],
        };
        assert_eq!(metadata.health_status, ToolHealthStatus::Healthy);
        metadata.health_status = ToolHealthStatus::Degraded;
        assert_eq!(metadata.health_status, ToolHealthStatus::Degraded);
        metadata.health_status = ToolHealthStatus::Unhealthy;
        assert_eq!(metadata.health_status, ToolHealthStatus::Unhealthy);
        metadata.health_status = ToolHealthStatus::Disabled;
        assert_eq!(metadata.health_status, ToolHealthStatus::Disabled);
    }

    #[tokio::test]
    async fn test_mcp_capabilities() {
        let server = create_test_mcp_server().await;
        let capabilities = server.get_server_capabilities().await;
        assert!(capabilities.tools.is_some());
        assert!(capabilities.prompts.is_some());
        assert!(capabilities.resources.is_some());
        assert!(capabilities.logging.is_some());
        let tools_cap = capabilities.tools.unwrap();
        assert_eq!(tools_cap.list_changed, Some(true));
        let logging_cap = capabilities.logging.unwrap();
        let levels = logging_cap.levels.unwrap();
        assert!(levels.contains(&MCPLogLevel::Debug));
        assert!(levels.contains(&MCPLogLevel::Info));
        assert!(levels.contains(&MCPLogLevel::Warning));
        assert!(levels.contains(&MCPLogLevel::Error));
    }

    // `Initialize` survives a JSON round-trip with its fields intact.
    #[tokio::test]
    async fn test_mcp_message_serialization() {
        let init_msg = MCPMessage::Initialize {
            protocol_version: MCP_VERSION.to_string(),
            capabilities: MCPCapabilities {
                experimental: None,
                sampling: None,
                tools: Some(MCPToolsCapability { list_changed: Some(true) }),
                prompts: None,
                resources: None,
                logging: None,
            },
            client_info: MCPClientInfo {
                name: "test_client".to_string(),
                version: "1.0.0".to_string(),
            },
        };
        let serialized = serde_json::to_string(&init_msg).unwrap();
        let deserialized: MCPMessage = serde_json::from_str(&serialized).unwrap();
        match deserialized {
            MCPMessage::Initialize { protocol_version, client_info, .. } => {
                assert_eq!(protocol_version, MCP_VERSION);
                assert_eq!(client_info.name, "test_client");
                assert_eq!(client_info.version, "1.0.0");
            }
            _ => panic!("Wrong message type after deserialization"),
        }
    }

    // Text and image content variants survive a JSON round-trip.
    #[tokio::test]
    async fn test_mcp_content_types() {
        let text_content = MCPContent::Text {
            text: "Hello, world!".to_string(),
        };
        let serialized = serde_json::to_string(&text_content).unwrap();
        let deserialized: MCPContent = serde_json::from_str(&serialized).unwrap();
        match deserialized {
            MCPContent::Text { text } => assert_eq!(text, "Hello, world!"),
            _ => panic!("Wrong content type"),
        }
        let image_content = MCPContent::Image {
            data: "base64data".to_string(),
            mime_type: "image/png".to_string(),
        };
        let serialized = serde_json::to_string(&image_content).unwrap();
        let deserialized: MCPContent = serde_json::from_str(&serialized).unwrap();
        match deserialized {
            MCPContent::Image { data, mime_type } => {
                assert_eq!(data, "base64data");
                assert_eq!(mime_type, "image/png");
            }
            _ => panic!("Wrong content type"),
        }
    }

    // NOTE(review): health status is assigned explicitly here; changing
    // `error_rate` does not derive it.
    #[tokio::test]
    async fn test_service_health_status() {
        let mut metrics = ServiceLoadMetrics {
            active_requests: 0,
            requests_per_second: 0.0,
            avg_response_time_ms: 0.0,
            error_rate: 0.0,
            cpu_usage: 0.0,
            memory_usage: 0,
        };
        let metadata = MCPServiceMetadata {
            name: "test_service".to_string(),
            version: "1.0.0".to_string(),
            description: Some("Test service".to_string()),
            tags: vec!["test".to_string()],
            health_status: ServiceHealthStatus::Healthy,
            load_metrics: metrics.clone(),
        };
        assert_eq!(metadata.health_status, ServiceHealthStatus::Healthy);
        metrics.error_rate = 0.5;
        let degraded_metadata = MCPServiceMetadata {
            health_status: ServiceHealthStatus::Degraded,
            load_metrics: metrics.clone(),
            ..metadata.clone()
        };
        assert_eq!(degraded_metadata.health_status, ServiceHealthStatus::Degraded);
        let unhealthy_metadata = MCPServiceMetadata {
            health_status: ServiceHealthStatus::Unhealthy,
            ..metadata.clone()
        };
        assert_eq!(unhealthy_metadata.health_status, ServiceHealthStatus::Unhealthy);
    }

    // A P2P envelope and its `ListTools` payload survive a JSON round-trip.
    #[tokio::test]
    async fn test_p2p_mcp_message() {
        let source_peer = "source_peer_123".to_string();
        let target_peer = "target_peer_456".to_string();
        let p2p_message = P2PMCPMessage {
            message_type: P2PMCPMessageType::Request,
            message_id: uuid::Uuid::new_v4().to_string(),
            source_peer: source_peer.clone(),
            target_peer: Some(target_peer.clone()),
            timestamp: SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
            payload: MCPMessage::ListTools { cursor: None },
            ttl: 10,
        };
        let serialized = serde_json::to_string(&p2p_message).unwrap();
        let deserialized: P2PMCPMessage = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.message_type, P2PMCPMessageType::Request);
        assert_eq!(deserialized.source_peer, source_peer);
        assert_eq!(deserialized.target_peer, Some(target_peer));
        assert_eq!(deserialized.ttl, 10);
        match deserialized.payload {
            MCPMessage::ListTools { cursor } => assert_eq!(cursor, None),
            _ => panic!("Wrong message payload type"),
        }
    }

    #[tokio::test]
    async fn test_tool_requirements_default() {
        let default_requirements = ToolRequirements::default();
        assert_eq!(default_requirements.max_memory, Some(100 * 1024 * 1024));
        assert_eq!(default_requirements.max_execution_time, Some(Duration::from_secs(30)));
        assert!(default_requirements.required_capabilities.is_empty());
        assert!(!default_requirements.requires_network);
        assert!(!default_requirements.requires_filesystem);
    }

    // Stats start at zero and count registered tools.
    #[tokio::test]
    async fn test_mcp_server_stats() {
        let server = create_test_mcp_server().await;
        let stats = server.stats.read().await;
        assert_eq!(stats.total_tools, 0);
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.total_responses, 0);
        assert_eq!(stats.total_errors, 0);
        // Release the read guard before `register_tool` needs to update stats.
        drop(stats);
        let tool = create_test_tool("stats_test_tool");
        server.register_tool(tool).await.unwrap();
        let stats = server.stats.read().await;
        assert_eq!(stats.total_tools, 1);
    }

    // Every log level round-trips through JSON; variants are compared by discriminant.
    #[tokio::test]
    async fn test_log_levels() {
        let levels = vec![
            MCPLogLevel::Debug,
            MCPLogLevel::Info,
            MCPLogLevel::Notice,
            MCPLogLevel::Warning,
            MCPLogLevel::Error,
            MCPLogLevel::Critical,
            MCPLogLevel::Alert,
            MCPLogLevel::Emergency,
        ];
        for level in levels {
            let serialized = serde_json::to_string(&level).unwrap();
            let deserialized: MCPLogLevel = serde_json::from_str(&serialized).unwrap();
            assert_eq!(level as u8, deserialized as u8);
        }
    }

    #[tokio::test]
    async fn test_mcp_endpoint() {
        let endpoint = MCPEndpoint {
            protocol: "p2p".to_string(),
            address: "127.0.0.1".to_string(),
            port: Some(9000),
            tls: true,
            auth_required: true,
        };
        let serialized = serde_json::to_string(&endpoint).unwrap();
        let deserialized: MCPEndpoint = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.protocol, "p2p");
        assert_eq!(deserialized.address, "127.0.0.1");
        assert_eq!(deserialized.port, Some(9000));
        assert!(deserialized.tls);
        assert!(deserialized.auth_required);
    }

    #[tokio::test]
    async fn test_mcp_service_metadata() {
        let load_metrics = ServiceLoadMetrics {
            active_requests: 5,
            requests_per_second: 10.5,
            avg_response_time_ms: 250.0,
            error_rate: 0.01,
            cpu_usage: 45.5,
            memory_usage: 1024 * 1024 * 100,
        };
        let metadata = MCPServiceMetadata {
            name: "test_service".to_string(),
            version: "2.1.0".to_string(),
            description: Some("A test service for unit testing".to_string()),
            tags: vec!["test".to_string(), "unit".to_string(), "mcp".to_string()],
            health_status: ServiceHealthStatus::Healthy,
            load_metrics,
        };
        let serialized = serde_json::to_string(&metadata).unwrap();
        let deserialized: MCPServiceMetadata = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.name, "test_service");
        assert_eq!(deserialized.version, "2.1.0");
        assert_eq!(deserialized.description, Some("A test service for unit testing".to_string()));
        assert_eq!(deserialized.tags, vec!["test", "unit", "mcp"]);
        assert_eq!(deserialized.health_status, ServiceHealthStatus::Healthy);
        assert_eq!(deserialized.load_metrics.active_requests, 5);
        assert_eq!(deserialized.load_metrics.requests_per_second, 10.5);
    }

    // An async closure wrapped in `FunctionToolHandler` runs via `execute`.
    #[tokio::test]
    async fn test_function_tool_handler() {
        let handler = FunctionToolHandler::new(|args: Value| async move {
            let name = args.get("name").and_then(|v| v.as_str()).unwrap_or("world");
            Ok(json!({"greeting": format!("Hello, {}!", name)}))
        });
        let args = json!({"name": "Alice"});
        let result = handler.execute(args).await.unwrap();
        assert_eq!(result["greeting"], "Hello, Alice!");
        let empty_args = json!({});
        let result = handler.execute(empty_args).await.unwrap();
        assert_eq!(result["greeting"], "Hello, world!");
    }

    // `MCPService::new` fills in the documented defaults.
    #[tokio::test]
    async fn test_mcp_service_creation() {
        let service_id = "test_service_123".to_string();
        let node_id = "test_node_789".to_string();
        let service = MCPService::new(service_id.clone(), node_id.clone());
        assert_eq!(service.service_id, service_id);
        assert_eq!(service.node_id, node_id);
        assert!(service.tools.is_empty());
        assert_eq!(service.metadata.name, "MCP Service");
        assert_eq!(service.metadata.version, "1.0.0");
        assert_eq!(service.metadata.health_status, ServiceHealthStatus::Healthy);
        assert_eq!(service.endpoint.protocol, "p2p");
        assert!(!service.endpoint.tls);
        assert!(!service.endpoint.auth_required);
    }

    #[tokio::test]
    async fn test_mcp_capabilities_default() {
        let capabilities = MCPCapabilities::default();
        assert!(capabilities.tools.is_some());
        assert!(capabilities.prompts.is_some());
        assert!(capabilities.resources.is_some());
        assert!(capabilities.logging.is_some());
        let tools_cap = capabilities.tools.unwrap();
        assert_eq!(tools_cap.list_changed, Some(true));
        let resources_cap = capabilities.resources.unwrap();
        assert_eq!(resources_cap.subscribe, Some(true));
        assert_eq!(resources_cap.list_changed, Some(true));
        let logging_cap = capabilities.logging.unwrap();
        let levels = logging_cap.levels.unwrap();
        assert!(levels.contains(&MCPLogLevel::Debug));
        assert!(levels.contains(&MCPLogLevel::Info));
        assert!(levels.contains(&MCPLogLevel::Warning));
        assert!(levels.contains(&MCPLogLevel::Error));
    }

    #[tokio::test]
    async fn test_mcp_request_creation() {
        let source_peer = "source_peer_123".to_string();
        let target_peer = "target_peer_456".to_string();
        let request = MCPRequest {
            request_id: uuid::Uuid::new_v4().to_string(),
            source_peer: source_peer.clone(),
            target_peer: target_peer.clone(),
            message: MCPMessage::ListTools { cursor: None },
            timestamp: SystemTime::now(),
            timeout: Duration::from_secs(30),
            auth_token: Some("test_token".to_string()),
        };
        assert_eq!(request.source_peer, source_peer);
        assert_eq!(request.target_peer, target_peer);
        assert_eq!(request.timeout, Duration::from_secs(30));
        assert_eq!(request.auth_token, Some("test_token".to_string()));
        match request.message {
            MCPMessage::ListTools { cursor } => assert_eq!(cursor, None),
            _ => panic!("Wrong message type"),
        }
    }

    // Message-type variants compare equal to themselves and survive a JSON round-trip.
    #[tokio::test]
    async fn test_p2p_message_types() {
        assert_eq!(P2PMCPMessageType::Request, P2PMCPMessageType::Request);
        assert_eq!(P2PMCPMessageType::Response, P2PMCPMessageType::Response);
        assert_eq!(P2PMCPMessageType::ServiceAdvertisement, P2PMCPMessageType::ServiceAdvertisement);
        assert_eq!(P2PMCPMessageType::ServiceDiscovery, P2PMCPMessageType::ServiceDiscovery);
        for msg_type in [
            P2PMCPMessageType::Request,
            P2PMCPMessageType::Response,
            P2PMCPMessageType::ServiceAdvertisement,
            P2PMCPMessageType::ServiceDiscovery,
        ] {
            let serialized = serde_json::to_string(&msg_type).unwrap();
            let deserialized: P2PMCPMessageType = serde_json::from_str(&serialized).unwrap();
            assert_eq!(msg_type, deserialized);
        }
    }
}