use crate::core::platform::container::arsenal::{Armament, ArsenalError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Any JSON-RPC 2.0 message exchanged with an MCP server.
///
/// `#[serde(untagged)]` means there is no discriminator on the wire: deserialization tries the
/// variants in declaration order and keeps the first one that fits. Every variant needs a
/// `jsonrpc` string; `Request` additionally needs `id` and `method`, `Response` needs `id`,
/// and `Notification` needs `method`. Unknown keys are ignored, so a request would also satisfy
/// `Response` — listing `Request` first is what keeps it classified correctly.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MCPMessage {
/// A method call carrying an `id`.
Request(MCPRequest),
/// A reply carrying an `id` plus `result` and/or `error`.
Response(MCPResponse),
/// A method call without an `id`.
Notification(MCPNotification),
}
/// A JSON-RPC 2.0 request sent to an MCP server.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPRequest {
    /// JSON-RPC version string; [`MCPRequest::new`] always sets `"2.0"`.
    pub jsonrpc: String,
    /// Correlation id that a conforming server echoes back in its response.
    pub id: Value,
    /// Name of the remote method, e.g. `tools/list`.
    pub method: String,
    /// Method arguments; the key is left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
impl MCPRequest {
    /// Creates a request for `method` carrying `params`.
    ///
    /// Every call draws a fresh random (v4) UUID and stores it as a JSON string in `id`, so
    /// callers never have to coordinate request ids among themselves.
    pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
        let method = method.into();
        let id = Value::String(Uuid::new_v4().to_string());
        Self {
            jsonrpc: String::from("2.0"),
            id,
            method,
            params,
        }
    }
}
/// A JSON-RPC 2.0 response.
///
/// `result` and `error` are both optional and each is left out of the serialized JSON when
/// `None`; this type does not itself enforce JSON-RPC's "exactly one of the two" rule.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPResponse {
/// JSON-RPC version string.
pub jsonrpc: String,
/// Id of the request being answered (per JSON-RPC, `null` when the server could not read it).
pub id: Value,
/// Success payload.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
/// Failure details.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<MCPError>,
}
/// The `error` object of a JSON-RPC 2.0 response.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPError {
    /// Numeric error code; the associated constants name the standard JSON-RPC codes.
    pub code: i64,
    /// Short human-readable description of the failure.
    pub message: String,
    /// Optional extra detail; left out of the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
impl MCPError {
    /// JSON-RPC: invalid JSON was received.
    pub const PARSE_ERROR: i64 = -32700;
    /// JSON-RPC: the payload is not a valid request object.
    pub const INVALID_REQUEST: i64 = -32600;
    /// JSON-RPC: the method does not exist or is unavailable.
    pub const METHOD_NOT_FOUND: i64 = -32601;
    /// JSON-RPC: invalid method parameters.
    pub const INVALID_PARAMS: i64 = -32602;
    /// JSON-RPC: internal error.
    pub const INTERNAL_ERROR: i64 = -32603;

    /// Builds an error without a `data` payload.
    pub fn new(code: i64, message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            code,
            message,
            data: None,
        }
    }

    /// Builds an error that also carries `data` as additional detail.
    pub fn with_data(code: i64, message: impl Into<String>, data: Value) -> Self {
        Self {
            data: Some(data),
            ..Self::new(code, message)
        }
    }
}
/// A JSON-RPC 2.0 notification: a method call without an `id` (per JSON-RPC, never answered).
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPNotification {
/// JSON-RPC version string.
pub jsonrpc: String,
/// Name of the notified method.
pub method: String,
/// Method arguments; left out of the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}
/// Capability information describing an MCP server.
///
/// JSON keys that match none of the named fields are collected into `extensions` via
/// `#[serde(flatten)]`, and are written back out at the top level when serialized.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCapabilities {
/// Server identity; left out of the serialized JSON when `None`.
// NOTE(review): (de)serialized under the snake_case key `server_info`; MCP's initialize
// result uses `serverInfo`, which would currently land in `extensions` — confirm intended.
#[serde(skip_serializing_if = "Option::is_none")]
pub server_info: Option<ServerInfo>,
/// Tools advertised by the server; an empty list when the key is absent.
#[serde(default)]
pub tools: Vec<ToolInfo>,
/// Every other top-level key, kept verbatim.
#[serde(flatten)]
pub extensions: HashMap<String, Value>,
}
/// Name and version identifying a server implementation.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInfo {
/// Implementation name.
pub name: String,
/// Implementation version string.
pub version: String,
}
/// Description of one tool, as returned in the `tools` array of a `tools/list` result.
#[doc(hidden)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
    /// Tool name; this is what `tools/call` is invoked with.
    pub name: String,
    /// Human-readable summary of what the tool does.
    ///
    /// The MCP specification marks `description` as optional. Defaulting to an empty string
    /// keeps a single description-less tool from failing the whole tool listing.
    #[serde(default)]
    pub description: String,
    /// JSON Schema describing the tool's arguments (wire name `inputSchema`).
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,
}
/// A bidirectional channel that carries [`MCPMessage`]s to and from an MCP server.
///
/// Implementations must be `Send + Sync`. `MCPClient` keeps its transport behind an async
/// mutex and calls `send` and then `receive` while holding that lock.
#[async_trait]
pub trait MCPTransport: Send + Sync {
/// Writes one message to the peer.
async fn send(&mut self, message: &MCPMessage) -> Result<(), ArsenalError>;
/// Waits for and returns the next message from the peer.
async fn receive(&mut self) -> Result<MCPMessage, ArsenalError>;
}
/// MCP client that issues JSON-RPC requests over a boxed [`MCPTransport`].
#[doc(hidden)]
pub struct MCPClient {
/// Shared transport; the async mutex gives each request/response exchange exclusive use.
transport: Arc<tokio::sync::Mutex<Box<dyn MCPTransport>>>,
// Only ever initialised to `None` in this module and never read, hence the `dead_code`
// allowance; presumably reserved for the result of a future handshake — TODO confirm.
#[allow(dead_code)]
capabilities: Option<MCPCapabilities>,
}
impl MCPClient {
pub fn new(transport: Box<dyn MCPTransport>) -> Self {
Self {
transport: Arc::new(tokio::sync::Mutex::new(transport)),
capabilities: None,
}
}
pub async fn discover_tools(&self) -> Result<Vec<Armament>, ArsenalError> {
let request = MCPRequest::new("tools/list", Some(serde_json::json!({})));
let response = self.send_request(request).await?;
let tools_array = response
.get("tools")
.and_then(|v| v.as_array())
.ok_or_else(|| {
ArsenalError::ProtocolError("Invalid tools/list response format".to_string())
})?;
let mut armaments = Vec::new();
for tool_value in tools_array {
let tool_info: ToolInfo = serde_json::from_value(tool_value.clone()).map_err(|e| {
ArsenalError::ProtocolError(format!("Failed to parse tool info: {}", e))
})?;
let required_params = tool_info
.input_schema
.get("required")
.and_then(|r| r.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
armaments.push(Armament {
name: tool_info.name,
description: tool_info.description,
parameters: tool_info.input_schema,
required_params,
});
}
Ok(armaments)
}
pub async fn invoke_tool(
&self,
tool_name: &str,
arguments: HashMap<String, Value>,
) -> Result<Value, ArsenalError> {
let params = serde_json::json!({
"name": tool_name,
"arguments": arguments,
});
let request = MCPRequest::new("tools/call", Some(params));
let response = self.send_request(request).await?;
response
.get("content")
.cloned()
.ok_or_else(|| ArsenalError::ProtocolError("Missing content in response".to_string()))
}
async fn send_request(&self, request: MCPRequest) -> Result<Value, ArsenalError> {
let mut transport = self.transport.lock().await;
transport.send(&MCPMessage::Request(request)).await?;
let response_msg = transport.receive().await?;
match response_msg {
MCPMessage::Response(response) => {
if let Some(error) = response.error {
return Err(ArsenalError::ProtocolError(format!(
"MCP error {}: {}",
error.code, error.message
)));
}
response.result.ok_or_else(|| {
ArsenalError::ProtocolError(
"Response missing both result and error".to_string(),
)
})
}
_ => Err(ArsenalError::ProtocolError(
"Expected response, got different message type".to_string(),
)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_request_creation() {
        let request = MCPRequest::new("test/method", Some(serde_json::json!({"key": "value"})));
        assert_eq!(request.jsonrpc, "2.0");
        assert_eq!(request.method, "test/method");
        assert!(request.params.is_some());
    }

    #[test]
    fn test_mcp_request_ids_are_unique_uuid_strings() {
        let first = MCPRequest::new("ping", None);
        let second = MCPRequest::new("ping", None);
        let id = first.id.as_str().expect("request id should be a JSON string");
        assert!(Uuid::parse_str(id).is_ok());
        assert_ne!(first.id, second.id);
    }

    #[test]
    fn test_mcp_request_omits_absent_params() {
        let json = serde_json::to_value(MCPRequest::new("ping", None)).unwrap();
        assert!(json.get("params").is_none());
        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["method"], "ping");
    }

    #[test]
    fn test_mcp_error_codes() {
        assert_eq!(MCPError::PARSE_ERROR, -32700);
        assert_eq!(MCPError::INVALID_REQUEST, -32600);
        assert_eq!(MCPError::METHOD_NOT_FOUND, -32601);
        assert_eq!(MCPError::INVALID_PARAMS, -32602);
        assert_eq!(MCPError::INTERNAL_ERROR, -32603);
    }

    #[test]
    fn test_mcp_error_constructors() {
        let plain = MCPError::new(MCPError::INTERNAL_ERROR, "boom");
        assert_eq!(plain.code, -32603);
        assert_eq!(plain.message, "boom");
        assert!(plain.data.is_none());

        let detail = serde_json::json!({"field": "x"});
        let detailed = MCPError::with_data(MCPError::INVALID_PARAMS, "bad", detail.clone());
        assert_eq!(detailed.code, -32602);
        assert_eq!(detailed.message, "bad");
        assert_eq!(detailed.data, Some(detail));
    }

    #[test]
    fn test_untagged_message_classification() {
        let request: MCPMessage = serde_json::from_value(
            serde_json::json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"}),
        )
        .unwrap();
        assert!(matches!(request, MCPMessage::Request(_)));

        let response: MCPMessage = serde_json::from_value(
            serde_json::json!({"jsonrpc": "2.0", "id": 1, "result": {}}),
        )
        .unwrap();
        assert!(matches!(response, MCPMessage::Response(_)));

        let notification: MCPMessage = serde_json::from_value(
            serde_json::json!({"jsonrpc": "2.0", "method": "notifications/progress"}),
        )
        .unwrap();
        assert!(matches!(notification, MCPMessage::Notification(_)));
    }
}