use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering};
/// Shared counter backing [`next_request_id`]; it starts at 1.
static REQUEST_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Hands out the next request id.
///
/// Each call returns the counter's current value and atomically advances it
/// by one, so concurrent callers never observe the same id.
pub fn next_request_id() -> u64 {
    // Relaxed ordering: only the atomicity of the increment is relied on
    // here, the counter is not used to synchronize any other data.
    AtomicU64::fetch_add(&REQUEST_ID_COUNTER, 1, Ordering::Relaxed)
}
/// Envelope for a JSON-RPC request.
///
/// Serializes to an object with `jsonrpc`, `id`, `method` and, only when
/// present, `params`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string.
    pub jsonrpc: String,
    /// Numeric identifier of this request.
    pub id: u64,
    /// Name of the method being invoked.
    pub method: String,
    /// Method arguments; the key is omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}
impl JsonRpcRequest {
    /// Builds a request for `method` carrying the optional `params`.
    ///
    /// The `jsonrpc` field is set to `"2.0"` and the `id` is drawn from
    /// `next_request_id()`, so every call consumes one id.
    pub fn new(method: impl Into<String>, params: Option<serde_json::Value>) -> Self {
        let id = next_request_id();
        let method = method.into();
        Self {
            jsonrpc: String::from("2.0"),
            id,
            method,
            params,
        }
    }
}
/// Envelope for a JSON-RPC response.
///
/// `result` and `error` are each optional; whichever is `None` is left out
/// when the value is serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
    /// Protocol version string.
    pub jsonrpc: String,
    /// Identifier of the request being answered, if any.
    pub id: Option<u64>,
    /// Payload of a successful call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error details of a failed call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// Error object carried in the `error` field of a [`JsonRpcResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// Numeric error code (may be negative, e.g. `-32601`).
    pub code: i64,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional extra payload; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
/// Name and version describing the client side of a connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientInfo {
    /// Client name.
    pub name: String,
    /// Client version string.
    pub version: String,
}
impl Default for ClientInfo {
    /// Describes this client as `"yoagent"` at the version of the crate
    /// being compiled.
    fn default() -> Self {
        let name = String::from("yoagent");
        // `env!` reads `CARGO_PKG_VERSION` at compile time.
        let version = String::from(env!("CARGO_PKG_VERSION"));
        Self { name, version }
    }
}
/// Name and version reported by a server.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInfo {
    /// Server name; required when deserializing.
    pub name: String,
    /// Server version; falls back to an empty string when the key is absent.
    #[serde(default)]
    pub version: String,
}
/// Capability sections a server may advertise.
///
/// Every section is optional. The container-level `default` fills any key
/// missing from the input from `ServerCapabilities::default()`, i.e. `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ServerCapabilities {
    /// Raw `tools` section, kept as untyped JSON.
    pub tools: Option<serde_json::Value>,
    /// Raw `resources` section, kept as untyped JSON.
    pub resources: Option<serde_json::Value>,
    /// Raw `prompts` section, kept as untyped JSON.
    pub prompts: Option<serde_json::Value>,
}
/// Typed form of the result of an `initialize` call.
///
/// Field names are camelCase on the wire: `protocolVersion`, `capabilities`
/// and `serverInfo`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResult {
    /// Protocol version string reported by the server.
    pub protocol_version: String,
    /// Capability sections advertised by the server.
    pub capabilities: ServerCapabilities,
    /// Identity of the server.
    pub server_info: ServerInfo,
}
/// Description of a single tool.
///
/// Field names are camelCase on the wire (`inputSchema`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpToolInfo {
    /// Tool name; required when deserializing.
    pub name: String,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Schema for the tool's input, kept as untyped JSON; falls back to
    /// `serde_json::Value::default()` when the key is absent.
    #[serde(default)]
    pub input_schema: serde_json::Value,
}
/// Wrapper object holding a list of tool descriptions under the `tools` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolsListResult {
    /// The listed tools.
    pub tools: Vec<McpToolInfo>,
}
/// One content item, discriminated by its `type` key.
///
/// Variant names are lower-cased to form the tag, giving `"text"` and
/// `"image"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum McpContent {
    /// Textual content.
    Text {
        /// The text itself.
        text: String,
    },
    /// Image content.
    Image {
        /// Image payload carried as a string.
        data: String,
        /// Media type of the payload; `mimeType` on the wire.
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
}
/// Outcome of a tool call.
///
/// Field names are camelCase on the wire (`isError`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpToolCallResult {
    /// Content items produced by the call.
    pub content: Vec<McpContent>,
    /// Error flag; treated as `false` when the key is absent.
    #[serde(default)]
    pub is_error: bool,
}
/// Error type covering the failure categories defined in this module.
///
/// `serde_json::Error` and `std::io::Error` convert into it automatically
/// through the `#[from]` attributes, so `?` works on both.
#[derive(Debug, thiserror::Error)]
pub enum McpError {
    /// Transport-level failure, described by a free-form message.
    #[error("Transport error: {0}")]
    Transport(String),

    /// Protocol-level failure, described by a free-form message.
    #[error("Protocol error: {0}")]
    Protocol(String),

    /// A JSON-RPC error, carrying its code and message.
    #[error("JSON-RPC error {code}: {message}")]
    JsonRpc { code: i64, message: String },

    /// JSON encoding or decoding failed.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The connection was closed.
    #[error("Connection closed")]
    ConnectionClosed,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built request serializes with the expected keys and survives
    /// a round trip through JSON.
    #[test]
    fn test_json_rpc_request_serialization() {
        let params = serde_json::json!({"protocolVersion": "2024-11-05"});
        let original = JsonRpcRequest {
            jsonrpc: "2.0".into(),
            id: 1,
            method: "initialize".into(),
            params: Some(params),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        for fragment in ["\"jsonrpc\":\"2.0\"", "\"method\":\"initialize\""] {
            assert!(encoded.contains(fragment));
        }

        let decoded = serde_json::from_str::<JsonRpcRequest>(&encoded).unwrap();
        assert_eq!(decoded.id, 1);
        assert_eq!(decoded.method, "initialize");
    }

    /// A success response exposes its id and result and carries no error.
    #[test]
    fn test_json_rpc_response_deserialization() {
        let json = r#"{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{},"serverInfo":{"name":"test","version":"1.0"}}}"#;
        let response = serde_json::from_str::<JsonRpcResponse>(json).unwrap();

        assert_eq!(response.id, Some(1));
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    /// An error response exposes the error object and its code.
    #[test]
    fn test_json_rpc_error_response() {
        let json =
            r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}"#;
        let response = serde_json::from_str::<JsonRpcResponse>(json).unwrap();

        // Comparing against `Some(..)` checks presence and code in one step.
        let code = response.error.map(|failure| failure.code);
        assert_eq!(code, Some(-32601));
    }

    /// The camelCase keys of an initialize result map onto the struct.
    #[test]
    fn test_initialize_result_deserialization() {
        let json = r#"{"protocolVersion":"2024-11-05","capabilities":{"tools":{}},"serverInfo":{"name":"test-server","version":"0.1.0"}}"#;
        let init = serde_json::from_str::<InitializeResult>(json).unwrap();

        assert_eq!(init.server_info.name, "test-server");
        assert!(init.capabilities.tools.is_some());
    }

    /// A tool description deserializes with its name and description.
    #[test]
    fn test_mcp_tool_info_deserialization() {
        let json = r#"{"name":"read_file","description":"Read a file","inputSchema":{"type":"object","properties":{"path":{"type":"string"}}}}"#;
        let info = serde_json::from_str::<McpToolInfo>(json).unwrap();

        assert_eq!(info.name, "read_file");
        assert_eq!(info.description, Some(String::from("Read a file")));
    }

    /// A tool call result with one text item and `isError: false` parses.
    #[test]
    fn test_mcp_tool_call_result() {
        let json = r#"{"content":[{"type":"text","text":"file contents here"}],"isError":false}"#;
        let outcome = serde_json::from_str::<McpToolCallResult>(json).unwrap();

        assert_eq!(outcome.content.len(), 1);
        assert!(!outcome.is_error);
    }

    /// Two consecutive ids are never equal.
    #[test]
    fn test_unique_request_ids() {
        let earlier = next_request_id();
        let later = next_request_id();
        assert_ne!(earlier, later);
    }
}