mcp_client_rs/
protocol.rs

1use serde::{Deserialize, Serialize};
2use std::fmt;
3
4use crate::error::{Error, ErrorCode};
5
/// Newest protocol version string known to this module.
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
/// Every protocol version string this module lists as supported, newest first.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[LATEST_PROTOCOL_VERSION, "2024-10-07"];
/// JSON-RPC version string.
///
/// NOTE(review): the constructors in this file read `crate::JSONRPC_VERSION`, not this
/// item directly — confirm the crate root re-exports (or matches) this constant.
pub const JSONRPC_VERSION: &str = "2.0";
9
/// A unique identifier for a request.
///
/// The identifier is either a string or a signed 64-bit integer. Because the
/// enum is `#[serde(untagged)]`, it is (de)serialized as the bare value with
/// no variant tag; the variants are tried in declaration order when
/// deserializing, so their order must not be changed casually.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RequestId {
    /// String-valued identifier.
    String(String),
    /// Integer-valued identifier.
    Number(i64),
}
17
/// Base JSON-RPC request structure.
///
/// Fields are serialized in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Request {
    /// JSON-RPC version string; `Request::new` fills it from `crate::JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Name of the method this request targets.
    pub method: String,
    /// Optional parameters; the member is left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
    /// Identifier of this request.
    pub id: RequestId,
}
27
/// Base JSON-RPC notification structure.
///
/// Has the same members as `Request` except that it carries no `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Notification {
    /// JSON-RPC version string; `Notification::new` fills it from `crate::JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Name of the method (event) this notification refers to.
    pub method: String,
    /// Optional parameters; the member is left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}
36
/// Base JSON-RPC response structure.
///
/// The constructors in this file (`Response::success` / `Response::error`)
/// never set both `result` and `error`, but the type itself does not
/// enforce that.
///
/// NOTE(review): with serde's default `Option` handling a `"result": null`
/// member presumably deserializes to `None`, the same as an absent member —
/// confirm whether any caller needs to tell the two apart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
    /// JSON-RPC version string; the constructors fill it from `crate::JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Identifier of the request this response belongs to.
    pub id: RequestId,
    /// Result payload; the member is left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error payload; the member is left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ResponseError>,
}
47
/// JSON-RPC error object.
///
/// Can be built from the crate's `Error` type through the `From<Error>`
/// implementation in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseError {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable description of the error.
    pub message: String,
    /// Optional extra payload; the member is left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
56
57impl Request {
58    pub fn new(
59        method: impl Into<String>,
60        params: Option<serde_json::Value>,
61        id: RequestId,
62    ) -> Self {
63        Self {
64            jsonrpc: crate::JSONRPC_VERSION.to_string(),
65            method: method.into(),
66            params,
67            id,
68        }
69    }
70}
71
72impl Notification {
73    pub fn new(method: impl Into<String>, params: Option<serde_json::Value>) -> Self {
74        Self {
75            jsonrpc: crate::JSONRPC_VERSION.to_string(),
76            method: method.into(),
77            params,
78        }
79    }
80}
81
82impl Response {
83    pub fn success(id: RequestId, result: Option<serde_json::Value>) -> Self {
84        Self {
85            jsonrpc: crate::JSONRPC_VERSION.to_string(),
86            id,
87            result,
88            error: None,
89        }
90    }
91
92    pub fn error(id: RequestId, error: ResponseError) -> Self {
93        Self {
94            jsonrpc: crate::JSONRPC_VERSION.to_string(),
95            id,
96            result: None,
97            error: Some(error),
98        }
99    }
100}
101
102impl From<Error> for ResponseError {
103    fn from(err: Error) -> Self {
104        match err {
105            Error::Protocol {
106                code,
107                message,
108                data,
109            } => ResponseError {
110                code: code.into(),
111                message,
112                data,
113            },
114            Error::Transport(msg) => ResponseError {
115                code: ErrorCode::InternalError.into(),
116                message: format!("Transport error: {}", msg),
117                data: None,
118            },
119            Error::Serialization(err) => ResponseError {
120                code: ErrorCode::ParseError.into(),
121                message: err.to_string(),
122                data: None,
123            },
124            Error::Io(err) => ResponseError {
125                code: ErrorCode::InternalError.into(),
126                message: err.to_string(),
127                data: None,
128            },
129            Error::Other(msg) => ResponseError {
130                code: ErrorCode::InternalError.into(),
131                message: msg,
132                data: None,
133            },
134        }
135    }
136}
137
138impl fmt::Display for RequestId {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        match self {
141            RequestId::String(s) => write!(f, "{}", s),
142            RequestId::Number(n) => write!(f, "{}", n),
143        }
144    }
145}
146
/// Unit tests for the message constructors, `RequestId` formatting, the
/// protocol constants, and the serde wire format of the message types.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `Request::new` stores every argument and stamps the JSON-RPC version.
    #[test]
    fn test_request_creation() {
        let id = RequestId::Number(1);
        let params = Some(json!({"key": "value"}));
        let request = Request::new("test_method", params.clone(), id.clone());

        assert_eq!(request.jsonrpc, JSONRPC_VERSION);
        assert_eq!(request.method, "test_method");
        assert_eq!(request.params, params);
        assert_eq!(request.id, id);
    }

    /// `Notification::new` stores every argument and stamps the JSON-RPC version.
    #[test]
    fn test_notification_creation() {
        let params = Some(json!({"event": "update"}));
        let notification = Notification::new("test_event", params.clone());

        assert_eq!(notification.jsonrpc, JSONRPC_VERSION);
        assert_eq!(notification.method, "test_event");
        assert_eq!(notification.params, params);
    }

    /// `Response::success` carries the result and leaves `error` unset.
    #[test]
    fn test_response_success() {
        let id = RequestId::String("test-1".to_string());
        let result = Some(json!({"status": "ok"}));
        let response = Response::success(id.clone(), result.clone());

        assert_eq!(response.jsonrpc, JSONRPC_VERSION);
        assert_eq!(response.id, id);
        assert_eq!(response.result, result);
        assert!(response.error.is_none());
    }

    /// `Response::error` carries the error (including its `data`) and leaves
    /// `result` unset.
    #[test]
    fn test_response_error() {
        let id = RequestId::Number(123);
        let error = ResponseError {
            code: -32600,
            message: "Invalid Request".to_string(),
            data: Some(json!({"details": "missing method"})),
        };
        let response = Response::error(id.clone(), error.clone());

        assert_eq!(response.jsonrpc, JSONRPC_VERSION);
        assert_eq!(response.id, id);
        assert!(response.result.is_none());

        let response_error = response.error.expect("error response must carry an error");
        assert_eq!(response_error.code, error.code);
        assert_eq!(response_error.message, error.message);
        // The payload was previously constructed but never verified.
        assert_eq!(response_error.data, error.data);
    }

    /// `Display` prints the bare identifier for both variants.
    #[test]
    fn test_request_id_display() {
        let num_id = RequestId::Number(42);
        let str_id = RequestId::String("test-id".to_string());

        assert_eq!(num_id.to_string(), "42");
        assert_eq!(str_id.to_string(), "test-id");
    }

    /// The latest version is listed as supported and the JSON-RPC version is 2.0.
    #[test]
    fn test_protocol_versions() {
        assert!(SUPPORTED_PROTOCOL_VERSIONS.contains(&LATEST_PROTOCOL_VERSION));
        assert_eq!(JSONRPC_VERSION, "2.0");
    }

    /// Identifiers travel as bare JSON values (no variant tag) in both directions.
    #[test]
    fn test_request_id_wire_format() {
        let number = serde_json::to_value(RequestId::Number(7)).expect("serialize number id");
        let string = serde_json::to_value(RequestId::String("abc".to_string()))
            .expect("serialize string id");
        assert_eq!(number, json!(7));
        assert_eq!(string, json!("abc"));

        let number: RequestId = serde_json::from_value(json!(7)).expect("deserialize number id");
        let string: RequestId =
            serde_json::from_value(json!("abc")).expect("deserialize string id");
        assert_eq!(number, RequestId::Number(7));
        assert_eq!(string, RequestId::String("abc".to_string()));
    }

    /// Optional members that are `None` are omitted from the serialized output.
    #[test]
    fn test_optional_members_omitted() {
        let notification = serde_json::to_value(Notification::new("ping", None))
            .expect("serialize notification");
        assert_eq!(notification["method"], json!("ping"));
        assert!(notification.get("params").is_none());

        let success = Response::success(RequestId::Number(1), Some(json!({"ok": true})));
        let success = serde_json::to_value(success).expect("serialize success response");
        assert_eq!(success["id"], json!(1));
        assert_eq!(success["result"], json!({"ok": true}));
        assert!(success.get("error").is_none());

        let failure = Response::error(
            RequestId::Number(2),
            ResponseError {
                code: -32600,
                message: "Invalid Request".to_string(),
                data: None,
            },
        );
        let failure = serde_json::to_value(failure).expect("serialize error response");
        assert!(failure.get("result").is_none());
        assert_eq!(failure["error"]["code"], json!(-32600));
        assert!(failure["error"].get("data").is_none());
    }
}