mcp-tools 0.1.0

Rust MCP tools library
Documentation
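//! MCP protocol implementation for MCP Tools: the wire message envelope,
//! request/response payload types, protocol error codes, and the
//! [`McpProtocol`] handler that builds, parses and version-checks messages.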

use super::*;
use crate::{McpToolsError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;

/// Protocol version string a default [`McpProtocol`] stamps on every message
/// it builds and expects when validating incoming messages.
pub const MCP_PROTOCOL_VERSION: &str = "1.0";

/// Discriminates the kind of an [`McpMessage`].
///
/// Variants travel on the wire in `snake_case`, e.g. `InitializeResult`
/// is encoded as `"initialize_result"`.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum McpMessageType {
    /// Initialization request; carries an [`InitializeRequest`] payload.
    Initialize,
    /// Reply to [`McpMessageType::Initialize`]; carries an [`InitializeResponse`] payload.
    InitializeResult,

    /// Capability query. No builder for it exists in this module.
    GetCapabilities,
    /// Reply to a capability query. No builder for it exists in this module.
    CapabilitiesResult,

    /// Tool invocation; carries a [`CallToolRequest`] payload.
    CallTool,
    /// Outcome of a tool invocation; carries a [`ToolResultResponse`] payload.
    ToolResult,

    /// Notification message; no dedicated payload type is defined here.
    Notification,

    /// Error reply; carries an [`ErrorResponse`] payload.
    Error,
}

/// Envelope shared by every MCP protocol message.
///
/// The message kind is encoded under the JSON key `"type"`. When
/// deserializing, a missing `metadata` object yields an empty map.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct McpMessage {
    /// Correlates a response with its request; may be `None`, e.g. for an
    /// error reply that is not tied to any request.
    pub id: Option<String>,

    /// What kind of message this is (serialized as `"type"`).
    #[serde(rename = "type")]
    pub message_type: McpMessageType,

    /// Kind-specific body; its structure depends on `message_type`.
    pub payload: serde_json::Value,

    /// Protocol version string of the sender.
    pub version: String,

    /// Arbitrary extra key/value data; empty when absent on the wire.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}

/// Body of an [`McpMessageType::Initialize`] message sent by a client.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InitializeRequest {
    /// Identity of the connecting client.
    pub client_info: ClientInfo,

    /// Features the client advertises.
    pub capabilities: ClientCapabilities,

    /// Protocol version string the client speaks.
    pub protocol_version: String,
}

/// Body of an [`McpMessageType::InitializeResult`] message sent by a server.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InitializeResponse {
    /// Identity of the responding server.
    pub server_info: ServerInfo,

    /// Features the server advertises.
    pub capabilities: ServerCapabilities,

    /// Protocol version string the server speaks.
    pub protocol_version: String,
}

/// Body of an [`McpMessageType::CallTool`] message: which tool to run and with what input.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CallToolRequest {
    /// Name of the tool to invoke.
    pub name: String,

    /// Arbitrary JSON arguments handed to the tool.
    pub arguments: serde_json::Value,

    /// Identifier of the session the call belongs to.
    pub session_id: String,
}

/// Body of an [`McpMessageType::ToolResult`] message: the outcome of a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultResponse {
    /// Content produced by the tool.
    pub content: Vec<McpContent>,

    /// `true` when the tool call FAILED, `false` when it succeeded.
    ///
    /// Note the polarity: this flag reports an error, not success.
    pub is_error: bool,

    /// Optional human-readable error description.
    ///
    /// Callers supply this independently of `is_error`, so the two are not
    /// forced to agree by this type.
    pub error: Option<String>,
}

/// Body of an [`McpMessageType::Error`] message.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ErrorResponse {
    /// Numeric error code (see the `error_codes` module for known values).
    pub code: i32,

    /// Human-readable description of the failure.
    pub message: String,

    /// Optional structured details about the failure.
    pub data: Option<serde_json::Value>,
}

/// Builds, parses, serializes and version-checks [`McpMessage`]s.
///
/// Holds the protocol version stamped on outgoing messages and an atomic
/// counter used to make generated message IDs distinct.
pub struct McpProtocol {
    /// Version string written into, and compared against, message envelopes.
    version: String,

    /// Monotonic counter mixed into every generated message ID.
    message_counter: std::sync::atomic::AtomicU64,
}

impl McpProtocol {
    /// Create new MCP protocol handler
    pub fn new() -> Self {
        Self {
            version: MCP_PROTOCOL_VERSION.to_string(),
            message_counter: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Generate unique message ID
    pub fn generate_message_id(&self) -> String {
        let counter = self
            .message_counter
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        format!(
            "msg-{}-{}",
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_millis(),
            counter
        )
    }

    /// Create initialize request message
    pub fn create_initialize_request(
        &self,
        client_info: ClientInfo,
        capabilities: ClientCapabilities,
    ) -> McpMessage {
        let payload = InitializeRequest {
            client_info,
            capabilities,
            protocol_version: self.version.clone(),
        };

        McpMessage {
            id: Some(self.generate_message_id()),
            message_type: McpMessageType::Initialize,
            payload: serde_json::to_value(payload).unwrap(),
            version: self.version.clone(),
            metadata: HashMap::new(),
        }
    }

    /// Create initialize response message
    pub fn create_initialize_response(
        &self,
        request_id: &str,
        server_info: ServerInfo,
        capabilities: ServerCapabilities,
    ) -> McpMessage {
        let payload = InitializeResponse {
            server_info,
            capabilities,
            protocol_version: self.version.clone(),
        };

        McpMessage {
            id: Some(request_id.to_string()),
            message_type: McpMessageType::InitializeResult,
            payload: serde_json::to_value(payload).unwrap(),
            version: self.version.clone(),
            metadata: HashMap::new(),
        }
    }

    /// Create tool call request message
    pub fn create_tool_call_request(
        &self,
        tool_name: &str,
        arguments: serde_json::Value,
        session_id: &str,
    ) -> McpMessage {
        let payload = CallToolRequest {
            name: tool_name.to_string(),
            arguments,
            session_id: session_id.to_string(),
        };

        McpMessage {
            id: Some(self.generate_message_id()),
            message_type: McpMessageType::CallTool,
            payload: serde_json::to_value(payload).unwrap(),
            version: self.version.clone(),
            metadata: HashMap::new(),
        }
    }

    /// Create tool result response message
    pub fn create_tool_result_response(
        &self,
        request_id: &str,
        content: Vec<McpContent>,
        is_error: bool,
        error: Option<String>,
    ) -> McpMessage {
        let payload = ToolResultResponse {
            content,
            is_error,
            error,
        };

        McpMessage {
            id: Some(request_id.to_string()),
            message_type: McpMessageType::ToolResult,
            payload: serde_json::to_value(payload).unwrap(),
            version: self.version.clone(),
            metadata: HashMap::new(),
        }
    }

    /// Create error response message
    pub fn create_error_response(
        &self,
        request_id: Option<&str>,
        code: i32,
        message: &str,
        data: Option<serde_json::Value>,
    ) -> McpMessage {
        let payload = ErrorResponse {
            code,
            message: message.to_string(),
            data,
        };

        McpMessage {
            id: request_id.map(|s| s.to_string()),
            message_type: McpMessageType::Error,
            payload: serde_json::to_value(payload).unwrap(),
            version: self.version.clone(),
            metadata: HashMap::new(),
        }
    }

    /// Parse MCP message from JSON
    pub fn parse_message(&self, json: &str) -> Result<McpMessage> {
        serde_json::from_str(json).map_err(|e| McpToolsError::Serialization(e))
    }

    /// Serialize MCP message to JSON
    pub fn serialize_message(&self, message: &McpMessage) -> Result<String> {
        serde_json::to_string(message).map_err(|e| McpToolsError::Serialization(e))
    }

    /// Validate message protocol version
    pub fn validate_version(&self, message: &McpMessage) -> Result<()> {
        if message.version != self.version {
            return Err(McpToolsError::Server(format!(
                "Protocol version mismatch: expected {}, got {}",
                self.version, message.version
            )));
        }
        Ok(())
    }
}

impl Default for McpProtocol {
    fn default() -> Self {
        Self::new()
    }
}

/// Numeric error codes for MCP protocol error responses.
pub mod error_codes {
    // --- Standard codes (values match those reserved by JSON-RPC 2.0) ---

    /// The message could not be parsed.
    pub const PARSE_ERROR: i32 = -32700;
    /// The message is not a valid request.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The request parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// An internal error occurred while handling the request.
    pub const INTERNAL_ERROR: i32 = -32603;

    // --- Application-specific codes ---

    /// The caller is not allowed to perform the operation.
    pub const PERMISSION_DENIED: i32 = -32000;
    /// No tool with the requested name exists.
    pub const TOOL_NOT_FOUND: i32 = -32001;
    /// The tool ran but failed.
    pub const TOOL_EXECUTION_ERROR: i32 = -32002;
    /// A session-related failure occurred.
    pub const SESSION_ERROR: i32 = -32003;
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built handler reports the crate-wide protocol version.
    #[test]
    fn test_protocol_creation() {
        let handler = McpProtocol::new();
        assert_eq!(handler.version, MCP_PROTOCOL_VERSION);
    }

    /// Consecutive IDs differ and both carry the `msg-` prefix.
    #[test]
    fn test_message_id_generation() {
        let handler = McpProtocol::new();
        let ids = [handler.generate_message_id(), handler.generate_message_id()];

        assert_ne!(ids[0], ids[1]);
        assert!(ids.iter().all(|id| id.starts_with("msg-")));
    }

    /// An initialize request gets the right type, an ID and the current version.
    #[test]
    fn test_initialize_request_creation() {
        let handler = McpProtocol::new();
        let info = ClientInfo {
            name: "Test Client".to_string(),
            version: "1.0.0".to_string(),
            description: "Test".to_string(),
        };
        let caps = ClientCapabilities {
            content_types: vec!["text".to_string()],
            features: vec!["test".to_string()],
            info: info.clone(),
        };

        let msg = handler.create_initialize_request(info, caps);

        assert_eq!(msg.message_type, McpMessageType::Initialize);
        assert!(msg.id.is_some());
        assert_eq!(msg.version, MCP_PROTOCOL_VERSION);
    }

    /// Serializing then parsing preserves the envelope's id, type and version.
    #[test]
    fn test_message_serialization() {
        let handler = McpProtocol::new();
        let original = McpMessage {
            id: Some("test-id".to_string()),
            message_type: McpMessageType::Notification,
            payload: serde_json::json!({"test": "data"}),
            version: MCP_PROTOCOL_VERSION.to_string(),
            metadata: HashMap::new(),
        };

        let wire = handler.serialize_message(&original).unwrap();
        let decoded = handler.parse_message(&wire).unwrap();

        assert_eq!(
            (&decoded.id, &decoded.message_type, &decoded.version),
            (&original.id, &original.message_type, &original.version)
        );
    }
}