use super::*;
use crate::{McpToolsError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Protocol version string used by [`McpProtocol::new`].
///
/// It is stamped into the `version` field of every message the protocol builds,
/// and incoming messages are compared against it by
/// [`McpProtocol::validate_version`].
pub const MCP_PROTOCOL_VERSION: &str = "1.0";
/// Kind of an [`McpMessage`], serialized under the message's `type` key.
///
/// Variants are (de)serialized in `snake_case`, e.g. `InitializeResult` is
/// written as `"initialize_result"`. Renaming a variant therefore changes the
/// wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum McpMessageType {
/// Client handshake; built by `McpProtocol::create_initialize_request`.
Initialize,
/// Reply to `Initialize`; built by `McpProtocol::create_initialize_response`.
InitializeResult,
/// Capability query. No constructor for it exists in this module.
GetCapabilities,
/// Reply to `GetCapabilities`. No constructor for it exists in this module.
CapabilitiesResult,
/// Tool invocation; built by `McpProtocol::create_tool_call_request`.
CallTool,
/// Reply to `CallTool`; built by `McpProtocol::create_tool_result_response`.
ToolResult,
/// Generic notification. No constructor for it exists in this module.
Notification,
/// Error reply; built by `McpProtocol::create_error_response`.
Error,
}
/// Envelope for every MCP message exchanged by [`McpProtocol`].
///
/// The type-specific body lives in `payload` as untyped JSON. Its shape depends
/// on `message_type` (e.g. a serialized [`InitializeRequest`] for
/// [`McpMessageType::Initialize`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpMessage {
/// Correlation id.
///
/// Requests built by `McpProtocol` receive a freshly generated id, and responses
/// echo the request's id. It may be `None`, e.g. for an error response that is
/// not tied to a request.
pub id: Option<String>,
/// Message kind; serialized under the JSON key `type`.
#[serde(rename = "type")]
pub message_type: McpMessageType,
/// Type-specific body as raw JSON.
pub payload: serde_json::Value,
/// Protocol version of the sender; checked by `McpProtocol::validate_version`.
pub version: String,
/// Free-form extra data. A missing `metadata` key deserializes as an empty map.
#[serde(default)]
pub metadata: HashMap<String, serde_json::Value>,
}
/// Payload of an [`McpMessageType::Initialize`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeRequest {
/// Identity of the connecting client.
pub client_info: ClientInfo,
/// Capabilities the client advertises.
pub capabilities: ClientCapabilities,
/// Protocol version the client speaks; `McpProtocol` fills in its own version.
pub protocol_version: String,
}
/// Payload of an [`McpMessageType::InitializeResult`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeResponse {
/// Identity of the responding server.
pub server_info: ServerInfo,
/// Capabilities the server advertises.
pub capabilities: ServerCapabilities,
/// Protocol version the server speaks; `McpProtocol` fills in its own version.
pub protocol_version: String,
}
/// Payload of an [`McpMessageType::CallTool`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallToolRequest {
/// Name of the tool to invoke.
pub name: String,
/// Tool arguments as arbitrary JSON; their schema is tool-specific.
pub arguments: serde_json::Value,
/// Session the call belongs to.
pub session_id: String,
}
/// Payload of an [`McpMessageType::ToolResult`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultResponse {
/// Output produced by the tool.
pub content: Vec<McpContent>,
/// Whether the tool call failed.
pub is_error: bool,
/// Optional error description.
///
/// NOTE(review): nothing in this module enforces that `error` is `Some` exactly
/// when `is_error` is true; both are set independently by the caller.
pub error: Option<String>,
}
/// Payload of an [`McpMessageType::Error`] message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
/// Numeric error code; the `error_codes` module defines the known values.
pub code: i32,
/// Human-readable error description.
pub message: String,
/// Optional structured detail attached to the error.
pub data: Option<serde_json::Value>,
}
/// Builds, serializes, parses, and version-checks [`McpMessage`]s.
pub struct McpProtocol {
// Version stamped on outgoing messages and expected on incoming ones;
// `new` sets it to `MCP_PROTOCOL_VERSION`.
version: String,
// Per-instance counter mixed into generated message ids. It is atomic so that
// `generate_message_id` can advance it through `&self`.
message_counter: std::sync::atomic::AtomicU64,
}
impl McpProtocol {
pub fn new() -> Self {
Self {
version: MCP_PROTOCOL_VERSION.to_string(),
message_counter: std::sync::atomic::AtomicU64::new(0),
}
}
pub fn generate_message_id(&self) -> String {
let counter = self
.message_counter
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
format!(
"msg-{}-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis(),
counter
)
}
pub fn create_initialize_request(
&self,
client_info: ClientInfo,
capabilities: ClientCapabilities,
) -> McpMessage {
let payload = InitializeRequest {
client_info,
capabilities,
protocol_version: self.version.clone(),
};
McpMessage {
id: Some(self.generate_message_id()),
message_type: McpMessageType::Initialize,
payload: serde_json::to_value(payload).unwrap(),
version: self.version.clone(),
metadata: HashMap::new(),
}
}
pub fn create_initialize_response(
&self,
request_id: &str,
server_info: ServerInfo,
capabilities: ServerCapabilities,
) -> McpMessage {
let payload = InitializeResponse {
server_info,
capabilities,
protocol_version: self.version.clone(),
};
McpMessage {
id: Some(request_id.to_string()),
message_type: McpMessageType::InitializeResult,
payload: serde_json::to_value(payload).unwrap(),
version: self.version.clone(),
metadata: HashMap::new(),
}
}
pub fn create_tool_call_request(
&self,
tool_name: &str,
arguments: serde_json::Value,
session_id: &str,
) -> McpMessage {
let payload = CallToolRequest {
name: tool_name.to_string(),
arguments,
session_id: session_id.to_string(),
};
McpMessage {
id: Some(self.generate_message_id()),
message_type: McpMessageType::CallTool,
payload: serde_json::to_value(payload).unwrap(),
version: self.version.clone(),
metadata: HashMap::new(),
}
}
pub fn create_tool_result_response(
&self,
request_id: &str,
content: Vec<McpContent>,
is_error: bool,
error: Option<String>,
) -> McpMessage {
let payload = ToolResultResponse {
content,
is_error,
error,
};
McpMessage {
id: Some(request_id.to_string()),
message_type: McpMessageType::ToolResult,
payload: serde_json::to_value(payload).unwrap(),
version: self.version.clone(),
metadata: HashMap::new(),
}
}
pub fn create_error_response(
&self,
request_id: Option<&str>,
code: i32,
message: &str,
data: Option<serde_json::Value>,
) -> McpMessage {
let payload = ErrorResponse {
code,
message: message.to_string(),
data,
};
McpMessage {
id: request_id.map(|s| s.to_string()),
message_type: McpMessageType::Error,
payload: serde_json::to_value(payload).unwrap(),
version: self.version.clone(),
metadata: HashMap::new(),
}
}
pub fn parse_message(&self, json: &str) -> Result<McpMessage> {
serde_json::from_str(json).map_err(|e| McpToolsError::Serialization(e))
}
pub fn serialize_message(&self, message: &McpMessage) -> Result<String> {
serde_json::to_string(message).map_err(|e| McpToolsError::Serialization(e))
}
pub fn validate_version(&self, message: &McpMessage) -> Result<()> {
if message.version != self.version {
return Err(McpToolsError::Server(format!(
"Protocol version mismatch: expected {}, got {}",
self.version, message.version
)));
}
Ok(())
}
}
impl Default for McpProtocol {
fn default() -> Self {
Self::new()
}
}
/// Numeric error codes, presumably for `ErrorResponse::code` /
/// `McpProtocol::create_error_response` (confirm against callers).
///
/// The first five values are the JSON-RPC 2.0 predefined error codes. The rest
/// fall in -32000..=-32099, the range JSON-RPC 2.0 reserves for
/// implementation-defined server errors.
pub mod error_codes {
/// JSON-RPC 2.0: invalid JSON was received.
pub const PARSE_ERROR: i32 = -32700;
/// JSON-RPC 2.0: the JSON is not a valid request object.
pub const INVALID_REQUEST: i32 = -32600;
/// JSON-RPC 2.0: the requested method does not exist.
pub const METHOD_NOT_FOUND: i32 = -32601;
/// JSON-RPC 2.0: invalid method parameters.
pub const INVALID_PARAMS: i32 = -32602;
/// JSON-RPC 2.0: internal error.
pub const INTERNAL_ERROR: i32 = -32603;
/// Application-defined; presumably the caller lacks permission for the operation.
pub const PERMISSION_DENIED: i32 = -32000;
/// Application-defined; presumably the named tool does not exist.
pub const TOOL_NOT_FOUND: i32 = -32001;
/// Application-defined; presumably the tool ran but failed.
pub const TOOL_EXECUTION_ERROR: i32 = -32002;
/// Application-defined; presumably a session-related failure.
pub const SESSION_ERROR: i32 = -32003;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_creation() {
        // A fresh handler speaks the crate's protocol version.
        assert_eq!(McpProtocol::new().version, MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_message_id_generation() {
        let proto = McpProtocol::new();
        let (first, second) = (proto.generate_message_id(), proto.generate_message_id());

        assert_ne!(first, second);
        for id in [&first, &second].iter() {
            assert!(id.starts_with("msg-"));
        }
    }

    #[test]
    fn test_initialize_request_creation() {
        let proto = McpProtocol::new();
        let info = ClientInfo {
            name: "Test Client".to_string(),
            version: "1.0.0".to_string(),
            description: "Test".to_string(),
        };
        let caps = ClientCapabilities {
            content_types: vec!["text".to_string()],
            features: vec!["test".to_string()],
            info: info.clone(),
        };

        let msg = proto.create_initialize_request(info, caps);

        assert_eq!(msg.message_type, McpMessageType::Initialize);
        assert!(msg.id.is_some());
        assert_eq!(msg.version, MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_message_serialization() {
        let proto = McpProtocol::new();
        let original = McpMessage {
            id: Some("test-id".to_string()),
            message_type: McpMessageType::Notification,
            payload: serde_json::json!({"test": "data"}),
            version: MCP_PROTOCOL_VERSION.to_string(),
            metadata: HashMap::new(),
        };

        // Serialize then parse back; both steps must succeed.
        let round_tripped = proto
            .serialize_message(&original)
            .and_then(|json| proto.parse_message(&json))
            .unwrap();

        assert_eq!(round_tripped.id, original.id);
        assert_eq!(round_tripped.message_type, original.message_type);
        assert_eq!(round_tripped.version, original.version);
    }
}