use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt;
use super::{ClientCapabilities, ImplementationInfo, RequestId, ServerCapabilities};
use crate::Result;
/// Any JSON-RPC message exchanged over the wire.
///
/// The enum is `#[serde(untagged)]`, so deserialization tries the variants in
/// declaration order and keeps the first one that fits. The order therefore
/// matters:
/// - `Request` (requires `method` and `id`) is tried first.
/// - `Response` (requires `id`; `result`/`error` are optional) comes next.
///   If it came first, it would also accept request objects, because unknown
///   fields are ignored.
/// - `Notification` (requires `method`, no `id`) is the fallback.
///
/// Do not reorder the variants.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum Message {
    /// A call carrying an `id` that the reply is correlated with.
    Request(Request),
    /// The reply to a previously sent request.
    Response(Response),
    /// A one-way message that carries no `id`.
    Notification(Notification),
}
/// A JSON-RPC request: a method call that carries an `id` so the reply can be
/// matched to it.
///
/// Field order is the serialization order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Request {
    /// Protocol version string; `Request::new` fills it from `JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Wire name of the invoked method, e.g. `"tools/list"`.
    pub method: String,
    /// Optional parameters; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
    /// Identifier that the matching `Response` carries back.
    pub id: RequestId,
}
/// A JSON-RPC response, correlated with an earlier request by `id`.
///
/// `result` and `error` are each omitted from the serialized JSON when `None`.
/// The type itself does not prevent both (or neither) from being set. The
/// `Response::success` and `Response::error` constructors each set exactly one.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Response {
    /// Protocol version string; the constructors fill it from `JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Id of the request this response answers.
    pub id: RequestId,
    /// Payload of a successful call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Details of a failed call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ResponseError>,
}
/// A one-way JSON-RPC message: it names a `method` but has no `id`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Notification {
    /// Protocol version string; `Notification::new` fills it from `JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Wire name of the notification, e.g. `"initialized"`.
    pub method: String,
    /// Optional parameters; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
/// The `error` member of a failed JSON-RPC response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResponseError {
    /// Numeric error code; predefined values live in the `error_codes` module.
    pub code: i32,
    /// Short human-readable description of the failure.
    pub message: String,
    /// Optional structured details; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
/// Numeric codes used in `ResponseError::code`.
pub mod error_codes {
    // Codes defined by the JSON-RPC 2.0 specification.

    /// JSON-RPC 2.0: invalid JSON was received.
    pub const PARSE_ERROR: i32 = -32700;
    /// JSON-RPC 2.0: the JSON sent is not a valid request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// JSON-RPC 2.0: the method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// JSON-RPC 2.0: invalid method parameters.
    pub const INVALID_PARAMS: i32 = -32602;
    /// JSON-RPC 2.0: internal error.
    pub const INTERNAL_ERROR: i32 = -32603;

    // Additional codes. The values match the Language Server Protocol
    // `ErrorCodes` entries of the same names.

    /// LSP `ServerNotInitialized`: a request arrived before initialization.
    pub const SERVER_NOT_INITIALIZED: i32 = -32002;
    /// LSP `UnknownErrorCode`: an error with no more specific code.
    pub const UNKNOWN_ERROR_CODE: i32 = -32001;
    /// LSP `RequestCancelled`: the request was cancelled.
    pub const REQUEST_CANCELLED: i32 = -32800;
}
/// Protocol methods known to this crate.
///
/// The serde representation is the method's wire name, e.g. `"tools/list"`.
/// The first four variants get theirs from `rename_all = "camelCase"`; every
/// other variant has an explicit `rename`. The `Display` impl writes the same
/// strings.
///
/// This is a fieldless enum, so it is `Copy`. It can be compared with `==`
/// and used as a `HashMap`/`HashSet` key without going through strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Method {
    // Lifecycle.
    Initialize,
    Initialized,
    Shutdown,
    Exit,
    // Utilities.
    #[serde(rename = "notifications/cancelled")]
    Cancel,
    #[serde(rename = "ping")]
    Ping,
    #[serde(rename = "$/progress")]
    Progress,
    // Prompts.
    #[serde(rename = "prompts/list")]
    ListPrompts,
    #[serde(rename = "prompts/get")]
    GetPrompt,
    #[serde(rename = "prompts/execute")]
    ExecutePrompt,
    // Resources.
    #[serde(rename = "resources/list")]
    ListResources,
    #[serde(rename = "resources/get")]
    GetResource,
    #[serde(rename = "resources/create")]
    CreateResource,
    #[serde(rename = "resources/update")]
    UpdateResource,
    #[serde(rename = "resources/delete")]
    DeleteResource,
    #[serde(rename = "resources/subscribe")]
    SubscribeResource,
    #[serde(rename = "resources/unsubscribe")]
    UnsubscribeResource,
    // Tools.
    #[serde(rename = "tools/list")]
    ListTools,
    #[serde(rename = "tools/get")]
    GetTool,
    #[serde(rename = "tools/execute")]
    ExecuteTool,
    #[serde(rename = "tools/cancel")]
    CancelTool,
    // Roots.
    #[serde(rename = "roots/list")]
    ListRoots,
    #[serde(rename = "roots/get")]
    GetRoot,
    // Sampling.
    #[serde(rename = "sampling/request")]
    SamplingRequest,
}
impl Request {
    /// Builds a request for `method` with optional `params` and the given `id`.
    ///
    /// `jsonrpc` is filled from the crate's `JSONRPC_VERSION`, and `method` is
    /// stored as its wire name (its `Display` output).
    pub fn new(method: Method, params: Option<Value>, id: RequestId) -> Self {
        Self {
            jsonrpc: super::JSONRPC_VERSION.to_string(),
            method: method.to_string(),
            params,
            id,
        }
    }

    /// Records this request's id in `used_ids` and reports whether it was new.
    ///
    /// Returns `true` if the id had not been recorded before; it is inserted
    /// as a side effect. Returns `false` if it was already present, and the
    /// set is left unchanged.
    ///
    /// The JSON string id `"1"` and the JSON number id `1` are different ids,
    /// so they must not share a key. Keys are encoded as follows:
    /// - Numeric ids use their decimal rendering.
    /// - String ids are wrapped in double quotes.
    ///
    /// A decimal rendering never contains a quote, so the two kinds cannot
    /// collide.
    pub fn validate_id_uniqueness(
        &self,
        used_ids: &mut std::collections::HashSet<String>,
    ) -> bool {
        let key = match &self.id {
            // Quoting keeps string ids disjoint from numeric ones; wrapping is
            // injective, so distinct strings still map to distinct keys.
            RequestId::String(s) => format!("\"{}\"", s),
            RequestId::Number(n) => n.to_string(),
        };
        used_ids.insert(key)
    }
}
impl Response {
pub fn success(result: Value, id: RequestId) -> Self {
Self {
jsonrpc: super::JSONRPC_VERSION.to_string(),
id,
result: Some(result),
error: None,
}
}
pub fn error(error: ResponseError, id: RequestId) -> Self {
Self {
jsonrpc: super::JSONRPC_VERSION.to_string(),
id,
result: None,
error: Some(error),
}
}
}
impl Notification {
    /// Creates a notification for `method` with optional `params`.
    ///
    /// `jsonrpc` is filled from the crate's `JSONRPC_VERSION`, and `method` is
    /// stored as its wire name (its `Display` output).
    pub fn new(method: Method, params: Option<Value>) -> Self {
        let jsonrpc = super::JSONRPC_VERSION.to_string();
        let method = method.to_string();
        Self {
            jsonrpc,
            method,
            params,
        }
    }
}
impl fmt::Display for Method {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Method::Initialize => write!(f, "initialize"),
Method::Initialized => write!(f, "initialized"),
Method::Shutdown => write!(f, "shutdown"),
Method::Exit => write!(f, "exit"),
Method::Cancel => write!(f, "notifications/cancelled"),
Method::Ping => write!(f, "ping"),
Method::Progress => write!(f, "$/progress"),
Method::ListPrompts => write!(f, "prompts/list"),
Method::GetPrompt => write!(f, "prompts/get"),
Method::ExecutePrompt => write!(f, "prompts/execute"),
Method::ListResources => write!(f, "resources/list"),
Method::GetResource => write!(f, "resources/get"),
Method::CreateResource => write!(f, "resources/create"),
Method::UpdateResource => write!(f, "resources/update"),
Method::DeleteResource => write!(f, "resources/delete"),
Method::SubscribeResource => write!(f, "resources/subscribe"),
Method::UnsubscribeResource => write!(f, "resources/unsubscribe"),
Method::ListTools => write!(f, "tools/list"),
Method::GetTool => write!(f, "tools/get"),
Method::ExecuteTool => write!(f, "tools/execute"),
Method::CancelTool => write!(f, "tools/cancel"),
Method::ListRoots => write!(f, "roots/list"),
Method::GetRoot => write!(f, "roots/get"),
Method::SamplingRequest => write!(f, "sampling/request"),
}
}
}
#[cfg(test)]
mod tests {
    //! Protocol-compliance tests for the message types in this module.

    use super::*;
    use serde_json::json;
    use std::collections::HashSet;

    /// Records `id` in `used_ids` and reports whether it was new.
    ///
    /// Goes through `Request::validate_id_uniqueness`, so these tests exercise
    /// the production check rather than a copy of its logic.
    fn is_unique_id(id: &RequestId, used_ids: &mut HashSet<String>) -> bool {
        Request::new(Method::Ping, None, id.clone()).validate_id_uniqueness(used_ids)
    }

    #[test]
    fn test_request_id_must_be_string_or_integer() {
        let string_id = RequestId::String("test-id".to_string());
        let request = Request::new(Method::Initialize, None, string_id);
        assert!(matches!(request.id, RequestId::String(_)));
        let integer_id = RequestId::Number(42);
        let request = Request::new(Method::Initialize, None, integer_id);
        assert!(matches!(request.id, RequestId::Number(_)));
    }

    #[test]
    fn test_request_id_uniqueness() {
        let mut used_ids = HashSet::new();
        let id1 = RequestId::String("test-1".to_string());
        let id2 = RequestId::String("test-1".to_string());
        assert!(is_unique_id(&id1, &mut used_ids));
        assert!(!is_unique_id(&id2, &mut used_ids));
        let id3 = RequestId::Number(1);
        let id4 = RequestId::Number(1);
        assert!(is_unique_id(&id3, &mut used_ids));
        assert!(!is_unique_id(&id4, &mut used_ids));
    }

    #[test]
    fn test_request_id_serialization() {
        let string_id = RequestId::String("test-id".to_string());
        let json = serde_json::to_string(&string_id).unwrap();
        assert_eq!(json, r#""test-id""#);
        let integer_id = RequestId::Number(42);
        let json = serde_json::to_string(&integer_id).unwrap();
        assert_eq!(json, "42");
    }

    #[test]
    fn test_request_id_deserialization() {
        let json = r#""test-id""#;
        let id: RequestId = serde_json::from_str(json).unwrap();
        assert!(matches!(id, RequestId::String(s) if s == "test-id"));
        let json = "42";
        let id: RequestId = serde_json::from_str(json).unwrap();
        assert!(matches!(id, RequestId::Number(n) if n == 42));
        let json = "null";
        // `Result` in this scope is the crate alias, so spell out std's.
        let result: std::result::Result<RequestId, serde_json::Error> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_request_with_same_id_in_different_sessions() {
        let id = RequestId::Number(1);
        let mut session1_ids = HashSet::new();
        assert!(is_unique_id(&id, &mut session1_ids));
        let mut session2_ids = HashSet::new();
        assert!(is_unique_id(&id, &mut session2_ids));
    }

    #[test]
    fn test_response_must_match_request_id() {
        let request_id = RequestId::Number(42);
        let request = Request::new(Method::Initialize, None, request_id.clone());
        let success_response = Response::success(json!({"result": "success"}), request_id.clone());
        assert!(matches!(success_response.id, RequestId::Number(42)));
        assert_eq!(success_response.id, request.id);
        let error_response = Response::error(
            ResponseError {
                code: error_codes::INTERNAL_ERROR,
                message: "error".to_string(),
                data: None,
            },
            request_id,
        );
        assert!(matches!(error_response.id, RequestId::Number(42)));
        assert_eq!(error_response.id, request.id);
        let different_id = RequestId::Number(43);
        let different_response = Response::success(json!({"result": "success"}), different_id);
        assert!(matches!(different_response.id, RequestId::Number(43)));
        assert_ne!(different_response.id, request.id);
    }

    #[test]
    fn test_response_must_set_result_or_error_not_both() {
        let id = RequestId::Number(1);
        let success_response = Response::success(json!({"data": "success"}), id.clone());
        assert!(success_response.result.is_some());
        assert!(success_response.error.is_none());
        let error_response = Response::error(
            ResponseError {
                code: error_codes::INTERNAL_ERROR,
                message: "error".to_string(),
                data: None,
            },
            id,
        );
        assert!(error_response.result.is_none());
        assert!(error_response.error.is_some());
        let success_json = serde_json::to_string(&success_response).unwrap();
        assert!(!success_json.contains(r#""error""#));
        let error_json = serde_json::to_string(&error_response).unwrap();
        assert!(!error_json.contains(r#""result""#));
    }

    #[test]
    fn test_error_code_must_be_integer() {
        let id = RequestId::Number(1);
        let standard_errors = [
            error_codes::PARSE_ERROR,
            error_codes::INVALID_REQUEST,
            error_codes::METHOD_NOT_FOUND,
            error_codes::INVALID_PARAMS,
            error_codes::INTERNAL_ERROR,
            error_codes::SERVER_NOT_INITIALIZED,
            error_codes::UNKNOWN_ERROR_CODE,
            error_codes::REQUEST_CANCELLED,
        ];
        for &code in &standard_errors {
            let error_response = Response::error(
                ResponseError {
                    code,
                    message: "test error".to_string(),
                    data: None,
                },
                id.clone(),
            );
            if let Some(error) = error_response.error {
                assert_eq!(
                    std::mem::size_of_val(&error.code),
                    std::mem::size_of::<i32>()
                );
                let json = serde_json::to_string(&error).unwrap();
                // The code must serialize as a bare JSON integer under the "code" key.
                assert!(json.contains(&format!(r#""code":{}"#, error.code)));
            } else {
                panic!("Error field should be set");
            }
        }
        let custom_codes = [-1, 0, 1, 1000, -1000];
        for code in custom_codes {
            let error_response = Response::error(
                ResponseError {
                    code,
                    message: "custom error".to_string(),
                    data: None,
                },
                id.clone(),
            );
            if let Some(error) = error_response.error {
                assert_eq!(error.code, code);
                assert_eq!(
                    std::mem::size_of_val(&error.code),
                    std::mem::size_of::<i32>()
                );
            } else {
                panic!("Error field should be set");
            }
        }
    }

    #[test]
    fn test_notification_must_not_contain_id() {
        let notification = Notification::new(Method::Initialized, Some(json!({"status": "ready"})));
        let json_str = serde_json::to_string(&notification).unwrap();
        assert!(!json_str.contains(r#""id""#));
        let json_without_id = r#"{
            "jsonrpc": "2.0",
            "method": "initialized",
            "params": {"status": "ready"}
        }"#;
        let parsed: Message = serde_json::from_str(json_without_id).unwrap();
        assert!(matches!(parsed, Message::Notification(_)));
        // With an id, the untagged `Message` must resolve to `Request` instead.
        let json_with_id = r#"{
            "jsonrpc": "2.0",
            "method": "initialized",
            "params": {"status": "ready"},
            "id": 1
        }"#;
        let parsed: Message = serde_json::from_str(json_with_id).unwrap();
        assert!(matches!(parsed, Message::Request(_)));
        assert!(!matches!(parsed, Message::Notification(_)));
    }

    #[test]
    fn test_initialization_protocol_compliance() {
        let request = Request::new(
            Method::Initialize,
            Some(json!({
                "protocolVersion": super::super::PROTOCOL_VERSION,
                "capabilities": {
                    "roots": {
                        "listChanged": true
                    },
                    "sampling": {}
                },
                "clientInfo": {
                    "name": "TestClient",
                    "version": "1.0.0"
                }
            })),
            RequestId::Number(1),
        );
        let request_json = serde_json::to_string(&request).unwrap();
        assert!(request_json.contains(r#""method":"initialize""#));
        assert!(request_json.contains(super::super::PROTOCOL_VERSION));
        assert!(request_json.contains(r#""capabilities""#));
        assert!(request_json.contains(r#""clientInfo""#));
        let response = Response::success(
            json!({
                "protocolVersion": super::super::PROTOCOL_VERSION,
                "capabilities": {
                    "prompts": {
                        "listChanged": true
                    },
                    "resources": {
                        "subscribe": true,
                        "listChanged": true
                    },
                    "tools": {
                        "listChanged": true
                    },
                    "logging": {}
                },
                "serverInfo": {
                    "name": "TestServer",
                    "version": "1.0.0"
                }
            }),
            RequestId::Number(1),
        );
        let response_json = serde_json::to_string(&response).unwrap();
        assert!(response_json.contains(super::super::PROTOCOL_VERSION));
        assert!(response_json.contains(r#""capabilities""#));
        assert!(response_json.contains(r#""serverInfo""#));
        let notification = Notification::new(Method::Initialized, None);
        let notification_json = serde_json::to_string(&notification).unwrap();
        assert!(notification_json.contains(r#""method":"initialized""#));
        assert!(!notification_json.contains(r#""id""#));
    }

    #[test]
    fn test_initialization_version_negotiation() {
        let client_request = Request::new(
            Method::Initialize,
            Some(json!({
                "protocolVersion": super::super::PROTOCOL_VERSION
            })),
            RequestId::Number(1),
        );
        let server_response = Response::success(
            json!({
                "protocolVersion": super::super::PROTOCOL_VERSION
            }),
            RequestId::Number(1),
        );
        let client_version: String = serde_json::from_value(
            client_request
                .params
                .unwrap()
                .get("protocolVersion")
                .unwrap()
                .clone(),
        )
        .unwrap();
        let server_version: String = serde_json::from_value(
            server_response
                .result
                .unwrap()
                .get("protocolVersion")
                .unwrap()
                .clone(),
        )
        .unwrap();
        assert_eq!(client_version, server_version);
        assert_eq!(client_version, super::super::PROTOCOL_VERSION);
        let unsupported_version = "1.0.0";
        let client_request = Request::new(
            Method::Initialize,
            Some(json!({
                "protocolVersion": unsupported_version
            })),
            RequestId::Number(2),
        );
        let requested_version = client_request
            .params
            .as_ref()
            .and_then(|params| params.get("protocolVersion"))
            .and_then(serde_json::Value::as_str)
            .unwrap();
        assert_eq!(requested_version, unsupported_version);
        let server_error = Response::error(
            ResponseError {
                code: error_codes::INVALID_REQUEST,
                message: "Unsupported protocol version".to_string(),
                data: Some(json!({
                    "supported": [super::super::PROTOCOL_VERSION],
                    "requested": unsupported_version
                })),
            },
            RequestId::Number(2),
        );
        // The rejection must answer the request that asked for the bad version.
        assert_eq!(server_error.id, client_request.id);
        let error_json = serde_json::to_string(&server_error).unwrap();
        assert!(error_json.contains("Unsupported protocol version"));
        assert!(error_json.contains(super::super::PROTOCOL_VERSION));
        assert!(error_json.contains(unsupported_version));
    }

    #[test]
    fn test_ping_mechanism() {
        let ping_request =
            Request::new(Method::Ping, None, RequestId::String("ping-1".to_string()));
        let request_json = serde_json::to_string(&ping_request).unwrap();
        assert!(request_json.contains(r#""method":"ping""#));
        assert!(request_json.contains(r#""id":"ping-1""#));
        assert!(!request_json.contains("params"));
        let ping_response = Response::success(json!({}), RequestId::String("ping-1".to_string()));
        let response_json = serde_json::to_string(&ping_response).unwrap();
        assert!(response_json.contains(r#""result":{}"#));
        assert!(response_json.contains(r#""id":"ping-1""#));
        assert!(!response_json.contains("error"));
        let mut session_ids = HashSet::new();
        assert!(ping_request.validate_id_uniqueness(&mut session_ids));
        assert!(!ping_request.validate_id_uniqueness(&mut session_ids));
        let mismatched_response =
            Response::success(json!({}), RequestId::String("wrong-id".to_string()));
        assert_ne!(ping_request.id, mismatched_response.id);
        let timeout_error = Response::error(
            ResponseError {
                code: error_codes::REQUEST_CANCELLED,
                message: "Ping timeout".to_string(),
                data: None,
            },
            RequestId::String("ping-1".to_string()),
        );
        let error_json = serde_json::to_string(&timeout_error).unwrap();
        assert!(error_json.contains("Ping timeout"));
        assert!(error_json.contains(&error_codes::REQUEST_CANCELLED.to_string()));
    }

    #[test]
    fn test_ping_pong_sequence() {
        let mut session_ids = HashSet::new();
        let ping_request = Request::new(
            Method::Ping,
            None,
            RequestId::String("ping-seq-1".to_string()),
        );
        assert!(ping_request.validate_id_uniqueness(&mut session_ids));
        let pong_response =
            Response::success(json!({}), RequestId::String("ping-seq-1".to_string()));
        assert_eq!(ping_request.id, pong_response.id);
        assert!(pong_response.result.is_some());
        assert!(pong_response.error.is_none());
        let ping_request_2 = Request::new(
            Method::Ping,
            None,
            RequestId::String("ping-seq-2".to_string()),
        );
        assert!(ping_request_2.validate_id_uniqueness(&mut session_ids));
        assert_ne!(ping_request.id, ping_request_2.id);
    }

    /// `Method`'s `Display` string and its serde wire name are written out
    /// separately; this test keeps them from drifting apart.
    #[test]
    fn test_method_display_matches_serde_name() {
        let methods = [
            Method::Initialize,
            Method::Initialized,
            Method::Shutdown,
            Method::Exit,
            Method::Cancel,
            Method::Ping,
            Method::Progress,
            Method::ListPrompts,
            Method::GetPrompt,
            Method::ExecutePrompt,
            Method::ListResources,
            Method::GetResource,
            Method::CreateResource,
            Method::UpdateResource,
            Method::DeleteResource,
            Method::SubscribeResource,
            Method::UnsubscribeResource,
            Method::ListTools,
            Method::GetTool,
            Method::ExecuteTool,
            Method::CancelTool,
            Method::ListRoots,
            Method::GetRoot,
            Method::SamplingRequest,
        ];
        for method in &methods {
            let wire_name = method.to_string();
            let serialized = serde_json::to_value(method).unwrap();
            assert_eq!(serialized, serde_json::Value::String(wire_name.clone()));
            let parsed: Method = serde_json::from_value(serialized).unwrap();
            assert_eq!(parsed.to_string(), wire_name);
        }
    }
}