use serde::{Deserialize, Serialize};
use serde_json::Value;
/// A JSON-RPC 2.0 request: a method invocation that carries an `id`.
///
/// `params` is left out of the serialized JSON when it is `None`, and becomes `None`
/// when the key is absent from incoming JSON.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[allow(
    dead_code,
    reason = "Fields required by JSON-RPC protocol but not all are read"
)]
pub struct Request {
    /// Protocol version string (not validated by this type).
    pub jsonrpc: String,
    /// Request identifier; [`Response`] carries an `id` of the same type.
    pub id: RequestId,
    /// Name of the method being invoked, e.g. `"roots/list"`.
    pub method: String,
    /// Method arguments as untyped JSON, if any.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
/// A JSON-RPC 2.0 notification: like a [`Request`] but without an `id`.
///
/// `params` is left out of the serialized JSON when it is `None`, and becomes `None`
/// when the key is absent from incoming JSON.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[allow(
    dead_code,
    reason = "Fields required by JSON-RPC protocol but not all are read"
)]
pub struct Notification {
    /// Protocol version string (not validated by this type).
    pub jsonrpc: String,
    /// Name of the notification, e.g. `"notifications/tools/list_changed"`.
    pub method: String,
    /// Notification arguments as untyped JSON, if any.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
/// A JSON-RPC request identifier: either an integer or a string.
///
/// Because the enum is `#[serde(untagged)]`, it (de)serializes as the bare JSON number or
/// string, with no wrapper object naming the variant.
///
/// NOTE(review): a JSON `null` id matches neither variant and fails to deserialize.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
#[serde(untagged)]
pub enum RequestId {
    /// Integer id, e.g. `1`.
    Number(i64),
    /// String id, e.g. `"catenary-0"`.
    String(String),
}
/// A JSON-RPC 2.0 response to a request with the same `id`.
///
/// Both `result` and `error` are optional at the type level:
/// - each is omitted from serialized JSON when `None`;
/// - each becomes `None` when absent from incoming JSON.
///
/// [`Response::success`] and [`Response::error`] each set exactly one of them.
///
/// NOTE(review): `id` cannot hold a JSON `null`, so a response with `"id": null` fails to
/// deserialize.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Response {
    /// Protocol version string.
    pub jsonrpc: String,
    /// Id of the request this response answers.
    pub id: RequestId,
    /// Success payload.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Failure payload.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ResponseError>,
}
impl Response {
    /// Builds a successful `"2.0"` response for `id`.
    ///
    /// `result` is converted to JSON and stored in the `result` field; `error` is `None`.
    ///
    /// # Errors
    ///
    /// Returns the [`serde_json::Error`] raised if `result` cannot be converted to a JSON value.
    pub fn success(id: RequestId, result: impl Serialize) -> Result<Self, serde_json::Error> {
        let payload = serde_json::to_value(result)?;
        Ok(Self {
            jsonrpc: String::from("2.0"),
            id,
            result: Some(payload),
            error: None,
        })
    }

    /// Builds a failed `"2.0"` response for `id`.
    ///
    /// The response carries an error object with the given `code` and `message` and no
    /// `data`; `result` is `None`.
    pub fn error(id: RequestId, code: i64, message: impl Into<String>) -> Self {
        let failure = ResponseError {
            code,
            message: message.into(),
            data: None,
        };
        Self {
            jsonrpc: String::from("2.0"),
            id,
            result: None,
            error: Some(failure),
        }
    }
}
/// The `error` object inside a failed JSON-RPC [`Response`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResponseError {
    /// Numeric error code, e.g. [`METHOD_NOT_FOUND`] or [`INTERNAL_ERROR`].
    pub code: i64,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional extra detail.
    ///
    /// Omitted from serialized JSON when `None`; becomes `None` when absent from input.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
/// JSON-RPC 2.0 "Method not found" error code (`-32601`).
pub const METHOD_NOT_FOUND: i64 = -32_601;
/// JSON-RPC 2.0 "Internal error" error code (`-32603`).
pub const INTERNAL_ERROR: i64 = -32_603;
/// Parameters of the MCP `initialize` request sent by the client.
///
/// Wire keys are camelCase (`protocolVersion`, `capabilities`, `clientInfo`); all three are
/// required.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(
    dead_code,
    reason = "Fields required by MCP protocol but not all are read"
)]
pub struct InitializeParams {
    /// Protocol version requested by the client, e.g. `"2024-11-05"`.
    pub protocol_version: String,
    /// Capabilities the client advertises.
    pub capabilities: ClientCapabilities,
    /// Identification of the client program.
    pub client_info: ClientInfo,
}
/// Capabilities advertised by the client during initialization.
///
/// Every field becomes `None` when its key is absent.
#[derive(Clone, Debug, Deserialize)]
#[allow(
    dead_code,
    reason = "Fields required by MCP protocol but not all are read"
)]
pub struct ClientCapabilities {
    /// The client's `roots` capability object, if sent.
    #[serde(default)]
    pub roots: Option<RootsCapability>,
    /// The client's `sampling` capability, kept as untyped JSON.
    #[serde(default)]
    pub sampling: Option<Value>,
}
/// The client's `roots` capability.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RootsCapability {
    /// Read from the `listChanged` key; `false` when the key is absent.
    #[serde(default)]
    pub list_changed: bool,
}
/// Name and optional version of the connecting client.
#[derive(Clone, Debug, Deserialize)]
pub struct ClientInfo {
    /// Client name (required).
    pub name: String,
    /// Client version; `None` when absent.
    #[serde(default)]
    pub version: Option<String>,
}
/// The server's reply to the MCP `initialize` request.
///
/// Wire keys are camelCase (`protocolVersion`, `serverInfo`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResult {
    /// Protocol version the server agrees to speak.
    pub protocol_version: String,
    /// Capabilities the server offers.
    pub capabilities: ServerCapabilities,
    /// Identification of the server program.
    pub server_info: ServerInfo,
    /// Optional free-form instructions.
    ///
    /// Omitted from serialized JSON when `None`; becomes `None` when absent from input.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
}
/// Capabilities the server advertises in its [`InitializeResult`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerCapabilities {
    /// Tool support; the key is omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<ToolsCapability>,
}
/// Details of the server's `tools` capability.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolsCapability {
    /// Written as `listChanged`; omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub list_changed: Option<bool>,
}
/// Name and optional version of this server.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ServerInfo {
    /// Server name, e.g. `"catenary"`.
    pub name: String,
    /// Server version; omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<String>,
}
/// A tool definition as advertised to the client.
///
/// Wire keys are camelCase (`inputSchema`). Optional fields are omitted from serialized JSON
/// when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
    /// Tool name used to invoke it, e.g. `"hover"`.
    pub name: String,
    /// Human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Schema describing the tool's arguments, as untyped JSON.
    pub input_schema: Value,
    /// Optional annotations, as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotations: Option<Value>,
}
/// Result payload listing every available [`Tool`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ListToolsResult {
    /// The advertised tools.
    pub tools: Vec<Tool>,
}
/// Parameters of a tool invocation received from the client.
#[derive(Clone, Debug, Deserialize)]
pub struct CallToolParams {
    /// Name of the tool to run.
    pub name: String,
    /// Tool arguments as untyped JSON; `None` when absent.
    #[serde(default)]
    pub arguments: Option<Value>,
}
/// Outcome of a tool invocation.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolResult {
    /// Content items produced by the tool.
    pub content: Vec<ToolContent>,
    /// Written as `isError`; omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
/// A single content item in a [`CallToolResult`].
///
/// Internally tagged: the variant name, camelCased, is stored under the `"type"` key next to
/// the variant's fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase", tag = "type")]
pub enum ToolContent {
    /// Plain text, serialized as `{"type": "text", "text": ...}`.
    Text { text: String },
}
/// A root location reported by the client.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Root {
    /// Root URI, e.g. `"file:///tmp/project"`.
    pub uri: String,
    /// Optional display name.
    ///
    /// Omitted from serialized JSON when `None`; becomes `None` when absent from input.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
/// Result payload holding a list of [`Root`]s (presumably the reply to `roots/list`).
#[derive(Clone, Debug, Deserialize)]
pub struct RootsListResult {
    /// The roots the client reported.
    pub roots: Vec<Root>,
}
impl CallToolResult {
    /// Builds a successful result containing exactly one text item.
    ///
    /// `is_error` is `None`, so the flag is omitted from serialized JSON.
    pub fn text(text: impl Into<String>) -> Self {
        let item = ToolContent::Text { text: text.into() };
        Self {
            content: vec![item],
            is_error: None,
        }
    }

    /// Builds a failed result containing exactly one text item holding `message`.
    ///
    /// `is_error` is `Some(true)`.
    pub fn error(message: impl Into<String>) -> Self {
        Self {
            is_error: Some(true),
            ..Self::text(message)
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::expect_used,
    reason = "tests use expect for readable assertions"
)]
mod tests {
    use super::*;
    use anyhow::{Context, Result};

    #[test]
    fn test_deserialize_initialize_params() -> Result<()> {
        let json = r#"{
            "protocolVersion": "2024-11-05",
            "capabilities": {
                "roots": { "listChanged": true }
            },
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            }
        }"#;
        let params: InitializeParams = serde_json::from_str(json)?;
        assert_eq!(params.protocol_version, "2024-11-05");
        assert_eq!(params.client_info.name, "test-client");
        Ok(())
    }

    #[test]
    fn test_serialize_initialize_result() -> Result<()> {
        let result = InitializeResult {
            protocol_version: "2024-11-05".to_string(),
            capabilities: ServerCapabilities {
                tools: Some(ToolsCapability { list_changed: None }),
            },
            server_info: ServerInfo {
                name: "catenary".to_string(),
                version: Some("0.1.0".to_string()),
            },
            instructions: None,
        };
        let json = serde_json::to_string(&result)?;
        assert!(json.contains("protocolVersion"));
        assert!(json.contains("catenary"));
        // `instructions: None` must be omitted rather than serialized as `null`.
        assert!(!json.contains("instructions"));
        Ok(())
    }

    #[test]
    fn test_serialize_tool() -> Result<()> {
        let tool = Tool {
            name: "hover".to_string(),
            description: Some("Get hover info".to_string()),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "file": { "type": "string" },
                    "line": { "type": "integer" },
                    "character": { "type": "integer" }
                },
                "required": ["file", "line", "character"]
            }),
            annotations: None,
        };
        let json = serde_json::to_string(&tool)?;
        assert!(json.contains("inputSchema"));
        assert!(json.contains("hover"));
        Ok(())
    }

    #[test]
    fn test_call_tool_result_text() -> Result<()> {
        let result = CallToolResult::text("Hello, world!");
        let json = serde_json::to_string(&result)?;
        assert!(json.contains("Hello, world!"));
        assert!(!json.contains("isError"));
        // Content items are internally tagged with a lowercase `type` key.
        let value = serde_json::to_value(&result)?;
        assert_eq!(value["content"][0]["type"], "text");
        Ok(())
    }

    #[test]
    fn test_call_tool_result_error() -> Result<()> {
        let result = CallToolResult::error("Something went wrong");
        let json = serde_json::to_string(&result)?;
        assert!(json.contains("isError"));
        // Check the flag structurally: a bare substring search for "true" would also match
        // message text that happens to contain that word.
        let value = serde_json::from_str::<Value>(&json)?;
        assert_eq!(value["isError"], true);
        assert_eq!(value["content"][0]["text"], "Something went wrong");
        Ok(())
    }

    #[test]
    fn test_response_success() -> Result<()> {
        let resp = Response::success(RequestId::Number(1), serde_json::json!({"ok": true}))?;
        let json = serde_json::to_string(&resp)?;
        assert!(json.contains("result"));
        assert!(!json.contains("error"));
        Ok(())
    }

    #[test]
    fn test_response_error() -> Result<()> {
        let resp = Response::error(RequestId::Number(1), METHOD_NOT_FOUND, "Unknown method");
        let json = serde_json::to_string(&resp)?;
        assert!(json.contains("error"));
        assert!(json.contains("-32601"));
        assert!(!json.contains("result"));
        Ok(())
    }

    #[test]
    fn test_serialize_request() -> Result<()> {
        let req = Request {
            jsonrpc: "2.0".to_string(),
            id: RequestId::String("catenary-0".to_string()),
            method: "roots/list".to_string(),
            params: None,
        };
        let json = serde_json::to_string(&req)?;
        assert!(json.contains("roots/list"));
        assert!(json.contains("catenary-0"));
        Ok(())
    }

    #[test]
    fn test_none_params_omitted_not_null() -> Result<()> {
        let req = Request {
            jsonrpc: "2.0".to_string(),
            id: RequestId::String("catenary-0".to_string()),
            method: "roots/list".to_string(),
            params: None,
        };
        let json = serde_json::to_string(&req)?;
        assert!(
            !json.contains("params"),
            "Request with params: None must omit the field, got: {json}"
        );
        let notification = Notification {
            jsonrpc: "2.0".to_string(),
            method: "notifications/tools/list_changed".to_string(),
            params: None,
        };
        let json = serde_json::to_string(&notification)?;
        assert!(
            !json.contains("params"),
            "Notification with params: None must omit the field, got: {json}"
        );
        Ok(())
    }

    #[test]
    fn test_deserialize_response_success() -> Result<()> {
        let json = r#"{
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"roots": []}
        }"#;
        let resp: Response = serde_json::from_str(json)?;
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
        Ok(())
    }

    #[test]
    fn test_deserialize_response_error() -> Result<()> {
        let json = r#"{
            "jsonrpc": "2.0",
            "id": "catenary-0",
            "error": {"code": -32601, "message": "not found"}
        }"#;
        let resp: Response = serde_json::from_str(json)?;
        assert!(resp.result.is_none());
        let err = resp.error.as_ref().context("missing error")?;
        assert_eq!(err.code, METHOD_NOT_FOUND);
        Ok(())
    }

    #[test]
    fn test_deserialize_root_with_name() -> Result<()> {
        let json = r#"{"uri": "file:///tmp/project", "name": "My Project"}"#;
        let root: Root = serde_json::from_str(json)?;
        assert_eq!(root.uri, "file:///tmp/project");
        assert_eq!(root.name.as_deref(), Some("My Project"));
        Ok(())
    }

    #[test]
    fn test_deserialize_root_without_name() -> Result<()> {
        let json = r#"{"uri": "file:///tmp/project"}"#;
        let root: Root = serde_json::from_str(json)?;
        assert_eq!(root.uri, "file:///tmp/project");
        assert!(root.name.is_none());
        Ok(())
    }

    #[test]
    fn test_deserialize_roots_list_result() -> Result<()> {
        let json = r#"{
            "roots": [
                {"uri": "file:///tmp/a", "name": "A"},
                {"uri": "file:///tmp/b"}
            ]
        }"#;
        let result: RootsListResult = serde_json::from_str(json)?;
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[0].uri, "file:///tmp/a");
        assert_eq!(result.roots[1].name, None);
        Ok(())
    }
}