1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::fmt;
4
5use super::{ClientCapabilities, ImplementationInfo, RequestId, ServerCapabilities};
6use crate::Result;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(untagged)]
12pub enum Message {
13 Request(Request),
14 Response(Response),
15 Notification(Notification),
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Request {
22 pub jsonrpc: String,
25 pub method: String,
28 #[serde(skip_serializing_if = "Option::is_none")]
31 pub params: Option<Value>,
32 pub id: RequestId,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Response {
41 pub jsonrpc: String,
44 pub id: RequestId,
47 #[serde(skip_serializing_if = "Option::is_none")]
50 pub result: Option<Value>,
51 #[serde(skip_serializing_if = "Option::is_none")]
54 pub error: Option<ResponseError>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct Notification {
61 pub jsonrpc: String,
64 pub method: String,
67 #[serde(skip_serializing_if = "Option::is_none")]
70 pub params: Option<Value>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ResponseError {
77 pub code: i32,
80 pub message: String,
83 #[serde(skip_serializing_if = "Option::is_none")]
86 pub data: Option<Value>,
87}
88
/// Numeric codes for the `code` field of a JSON-RPC error object.
pub mod error_codes {
    // --- Codes defined by the JSON-RPC 2.0 specification -------------------

    /// "Parse error": the payload was not valid JSON.
    pub const PARSE_ERROR: i32 = -32700;
    /// "Invalid Request": the JSON is not a valid request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// "Method not found".
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// "Invalid params".
    pub const INVALID_PARAMS: i32 = -32602;
    /// "Internal error".
    pub const INTERNAL_ERROR: i32 = -32603;

    // --- Implementation-defined codes ---------------------------------------
    // NOTE(review): these values appear to mirror the Language Server Protocol's
    // ServerNotInitialized / UnknownErrorCode / RequestCancelled codes — confirm.

    /// The peer has not finished initialization yet.
    pub const SERVER_NOT_INITIALIZED: i32 = -32002;
    /// Catch-all code for an unclassified error.
    pub const UNKNOWN_ERROR_CODE: i32 = -32001;
    /// The request was cancelled before it completed.
    pub const REQUEST_CANCELLED: i32 = -32800;
}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
105#[serde(rename_all = "camelCase")]
106pub enum Method {
107 Initialize,
110 Initialized,
111 Shutdown,
112 Exit,
113
114 #[serde(rename = "notifications/cancelled")]
117 Cancel,
118 #[serde(rename = "ping")]
119 Ping,
120 #[serde(rename = "$/progress")]
121 Progress,
122
123 #[serde(rename = "prompts/list")]
126 ListPrompts,
127 #[serde(rename = "prompts/get")]
128 GetPrompt,
129 #[serde(rename = "prompts/execute")]
130 ExecutePrompt,
131
132 #[serde(rename = "resources/list")]
133 ListResources,
134 #[serde(rename = "resources/get")]
135 GetResource,
136 #[serde(rename = "resources/create")]
137 CreateResource,
138 #[serde(rename = "resources/update")]
139 UpdateResource,
140 #[serde(rename = "resources/delete")]
141 DeleteResource,
142 #[serde(rename = "resources/subscribe")]
143 SubscribeResource,
144 #[serde(rename = "resources/unsubscribe")]
145 UnsubscribeResource,
146
147 #[serde(rename = "tools/list")]
148 ListTools,
149 #[serde(rename = "tools/get")]
150 GetTool,
151 #[serde(rename = "tools/execute")]
152 ExecuteTool,
153 #[serde(rename = "tools/cancel")]
154 CancelTool,
155
156 #[serde(rename = "roots/list")]
159 ListRoots,
160 #[serde(rename = "roots/get")]
161 GetRoot,
162
163 #[serde(rename = "sampling/request")]
164 SamplingRequest,
165}
166
167impl Request {
168 pub fn new(method: Method, params: Option<Value>, id: RequestId) -> Self {
171 Self {
176 jsonrpc: super::JSONRPC_VERSION.to_string(),
177 method: method.to_string(),
178 params,
179 id,
180 }
181 }
182
183 pub fn validate_id_uniqueness(&self, used_ids: &mut std::collections::HashSet<String>) -> bool {
186 let id_str = match &self.id {
187 RequestId::String(s) => s.clone(),
188 RequestId::Number(n) => n.to_string(),
189 };
190 used_ids.insert(id_str)
191 }
192}
193
194impl Response {
195 pub fn success(result: Value, id: RequestId) -> Self {
198 Self {
199 jsonrpc: super::JSONRPC_VERSION.to_string(),
200 id,
201 result: Some(result),
202 error: None,
203 }
204 }
205
206 pub fn error(error: ResponseError, id: RequestId) -> Self {
209 Self {
210 jsonrpc: super::JSONRPC_VERSION.to_string(),
211 id,
212 result: None,
213 error: Some(error),
214 }
215 }
216}
217
218impl Notification {
219 pub fn new(method: Method, params: Option<Value>) -> Self {
222 Self {
223 jsonrpc: super::JSONRPC_VERSION.to_string(),
224 method: method.to_string(),
225 params,
226 }
227 }
228}
229
230impl fmt::Display for Method {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 match self {
233 Method::Initialize => write!(f, "initialize"),
234 Method::Initialized => write!(f, "initialized"),
235 Method::Shutdown => write!(f, "shutdown"),
236 Method::Exit => write!(f, "exit"),
237 Method::Cancel => write!(f, "notifications/cancelled"),
238 Method::Ping => write!(f, "ping"),
239 Method::Progress => write!(f, "$/progress"),
240 Method::ListPrompts => write!(f, "prompts/list"),
241 Method::GetPrompt => write!(f, "prompts/get"),
242 Method::ExecutePrompt => write!(f, "prompts/execute"),
243 Method::ListResources => write!(f, "resources/list"),
244 Method::GetResource => write!(f, "resources/get"),
245 Method::CreateResource => write!(f, "resources/create"),
246 Method::UpdateResource => write!(f, "resources/update"),
247 Method::DeleteResource => write!(f, "resources/delete"),
248 Method::SubscribeResource => write!(f, "resources/subscribe"),
249 Method::UnsubscribeResource => write!(f, "resources/unsubscribe"),
250 Method::ListTools => write!(f, "tools/list"),
251 Method::GetTool => write!(f, "tools/get"),
252 Method::ExecuteTool => write!(f, "tools/execute"),
253 Method::CancelTool => write!(f, "tools/cancel"),
254 Method::ListRoots => write!(f, "roots/list"),
255 Method::GetRoot => write!(f, "roots/get"),
256 Method::SamplingRequest => write!(f, "sampling/request"),
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashSet;

    #[test]
    fn test_request_id_must_be_string_or_integer() {
        let string_id = RequestId::String("test-id".to_string());
        let request = Request::new(Method::Initialize, None, string_id);
        assert!(matches!(request.id, RequestId::String(_)));

        let integer_id = RequestId::Number(42);
        let request = Request::new(Method::Initialize, None, integer_id);
        assert!(matches!(request.id, RequestId::Number(_)));
    }

    #[test]
    fn test_request_id_uniqueness() {
        let mut used_ids = HashSet::new();

        let id1 = RequestId::String("test-1".to_string());
        let id2 = RequestId::String("test-1".to_string());
        assert!(is_unique_id(&id1, &mut used_ids));
        assert!(!is_unique_id(&id2, &mut used_ids));

        let id3 = RequestId::Number(1);
        let id4 = RequestId::Number(1);
        assert!(is_unique_id(&id3, &mut used_ids));
        assert!(!is_unique_id(&id4, &mut used_ids));
    }

    /// Stringifies `id` and records it in `used_ids`.
    ///
    /// Returns `true` if the id had not been seen before, `false` otherwise.
    fn is_unique_id(id: &RequestId, used_ids: &mut HashSet<String>) -> bool {
        let id_str = match id {
            RequestId::String(s) => s.clone(),
            RequestId::Number(n) => n.to_string(),
        };
        used_ids.insert(id_str)
    }

    #[test]
    fn test_request_id_serialization() {
        let string_id = RequestId::String("test-id".to_string());
        let json = serde_json::to_string(&string_id).unwrap();
        assert_eq!(json, r#""test-id""#);

        let integer_id = RequestId::Number(42);
        let json = serde_json::to_string(&integer_id).unwrap();
        assert_eq!(json, "42");
    }

    #[test]
    fn test_request_id_deserialization() {
        let json = r#""test-id""#;
        let id: RequestId = serde_json::from_str(json).unwrap();
        assert!(matches!(id, RequestId::String(s) if s == "test-id"));

        let json = "42";
        let id: RequestId = serde_json::from_str(json).unwrap();
        assert!(matches!(id, RequestId::Number(n) if n == 42));

        let json = "null";
        let result: std::result::Result<RequestId, serde_json::Error> =
            serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_request_with_same_id_in_different_sessions() {
        let id = RequestId::Number(1);

        let mut session1_ids = HashSet::new();
        assert!(is_unique_id(&id, &mut session1_ids));

        let mut session2_ids = HashSet::new();
        assert!(is_unique_id(&id, &mut session2_ids));
    }

    #[test]
    fn test_response_must_match_request_id() {
        let request_id = RequestId::Number(42);
        let request = Request::new(Method::Initialize, None, request_id.clone());

        let success_response =
            Response::success(json!({"result": "success"}), request_id.clone());
        assert!(matches!(success_response.id, RequestId::Number(42)));
        assert_eq!(success_response.id, request.id);

        let error_response = Response::error(
            ResponseError {
                code: error_codes::INTERNAL_ERROR,
                message: "error".to_string(),
                data: None,
            },
            request_id.clone(),
        );
        assert!(matches!(error_response.id, RequestId::Number(42)));
        assert_eq!(error_response.id, request.id);

        let different_id = RequestId::Number(43);
        let different_response = Response::success(json!({"result": "success"}), different_id);
        assert!(matches!(different_response.id, RequestId::Number(43)));
        assert_ne!(different_response.id, request.id);
    }

    #[test]
    fn test_response_must_set_result_or_error_not_both() {
        let id = RequestId::Number(1);

        let success_response = Response::success(json!({"data": "success"}), id.clone());
        assert!(success_response.result.is_some());
        assert!(success_response.error.is_none());

        let error_response = Response::error(
            ResponseError {
                code: error_codes::INTERNAL_ERROR,
                message: "error".to_string(),
                data: None,
            },
            id.clone(),
        );
        assert!(error_response.result.is_none());
        assert!(error_response.error.is_some());

        let success_json = serde_json::to_string(&success_response).unwrap();
        assert!(!success_json.contains(r#""error""#));

        let error_json = serde_json::to_string(&error_response).unwrap();
        assert!(!error_json.contains(r#""result""#));
    }

    #[test]
    fn test_error_code_must_be_integer() {
        let id = RequestId::Number(1);

        let standard_errors = [
            error_codes::PARSE_ERROR,
            error_codes::INVALID_REQUEST,
            error_codes::METHOD_NOT_FOUND,
            error_codes::INVALID_PARAMS,
            error_codes::INTERNAL_ERROR,
            error_codes::SERVER_NOT_INITIALIZED,
            error_codes::UNKNOWN_ERROR_CODE,
            error_codes::REQUEST_CANCELLED,
        ];

        for &code in &standard_errors {
            let error_response = Response::error(
                ResponseError {
                    code,
                    message: "test error".to_string(),
                    data: None,
                },
                id.clone(),
            );

            let error = error_response.error.expect("Error field should be set");
            assert_eq!(
                std::mem::size_of_val(&error.code),
                std::mem::size_of::<i32>()
            );

            // The code must be serialized as a bare JSON integer, not a string.
            let json = serde_json::to_string(&error).unwrap();
            assert!(json.contains(&format!(r#""code":{}"#, error.code)));
        }

        let custom_codes = [-1, 0, 1, 1000, -1000];
        for code in custom_codes {
            let error_response = Response::error(
                ResponseError {
                    code,
                    message: "custom error".to_string(),
                    data: None,
                },
                id.clone(),
            );

            let error = error_response.error.expect("Error field should be set");
            assert_eq!(error.code, code);
            assert_eq!(
                std::mem::size_of_val(&error.code),
                std::mem::size_of::<i32>()
            );
        }
    }

    #[test]
    fn test_notification_must_not_contain_id() {
        let notification =
            Notification::new(Method::Initialized, Some(json!({"status": "ready"})));

        let json_str = serde_json::to_string(&notification).unwrap();

        assert!(!json_str.contains(r#""id""#));

        let json_without_id = r#"{
            "jsonrpc": "2.0",
            "method": "initialized",
            "params": {"status": "ready"}
        }"#;
        let parsed: Message = serde_json::from_str(json_without_id).unwrap();
        assert!(matches!(parsed, Message::Notification(_)));

        let json_with_id = r#"{
            "jsonrpc": "2.0",
            "method": "initialized",
            "params": {"status": "ready"},
            "id": 1
        }"#;
        let parsed: Message = serde_json::from_str(json_with_id).unwrap();
        assert!(matches!(parsed, Message::Request(_)));
        assert!(!matches!(parsed, Message::Notification(_)));
    }

    #[test]
    fn test_initialization_protocol_compliance() {
        let request = Request::new(
            Method::Initialize,
            Some(json!({
                "protocolVersion": super::super::PROTOCOL_VERSION,
                "capabilities": {
                    "roots": {
                        "listChanged": true
                    },
                    "sampling": {}
                },
                "clientInfo": {
                    "name": "TestClient",
                    "version": "1.0.0"
                }
            })),
            RequestId::Number(1),
        );

        let request_json = serde_json::to_string(&request).unwrap();

        assert!(request_json.contains(r#""method":"initialize""#));
        assert!(request_json.contains(super::super::PROTOCOL_VERSION));
        assert!(request_json.contains(r#""capabilities""#));
        assert!(request_json.contains(r#""clientInfo""#));

        let response = Response::success(
            json!({
                "protocolVersion": super::super::PROTOCOL_VERSION,
                "capabilities": {
                    "prompts": {
                        "listChanged": true
                    },
                    "resources": {
                        "subscribe": true,
                        "listChanged": true
                    },
                    "tools": {
                        "listChanged": true
                    },
                    "logging": {}
                },
                "serverInfo": {
                    "name": "TestServer",
                    "version": "1.0.0"
                }
            }),
            RequestId::Number(1),
        );

        let response_json = serde_json::to_string(&response).unwrap();

        assert!(response_json.contains(super::super::PROTOCOL_VERSION));
        assert!(response_json.contains(r#""capabilities""#));
        assert!(response_json.contains(r#""serverInfo""#));

        let notification = Notification::new(Method::Initialized, None);
        let notification_json = serde_json::to_string(&notification).unwrap();

        assert!(notification_json.contains(r#""method":"initialized""#));
        assert!(!notification_json.contains(r#""id""#));
    }

    #[test]
    fn test_initialization_version_negotiation() {
        let client_request = Request::new(
            Method::Initialize,
            Some(json!({
                "protocolVersion": super::super::PROTOCOL_VERSION
            })),
            RequestId::Number(1),
        );

        let server_response = Response::success(
            json!({
                "protocolVersion": super::super::PROTOCOL_VERSION
            }),
            RequestId::Number(1),
        );

        let client_version: String = serde_json::from_value(
            client_request
                .params
                .unwrap()
                .get("protocolVersion")
                .unwrap()
                .clone(),
        )
        .unwrap();

        let server_version: String = serde_json::from_value(
            server_response
                .result
                .unwrap()
                .get("protocolVersion")
                .unwrap()
                .clone(),
        )
        .unwrap();

        assert_eq!(client_version, server_version);
        assert_eq!(client_version, super::super::PROTOCOL_VERSION);

        let unsupported_version = "1.0.0";
        let client_request = Request::new(
            Method::Initialize,
            Some(json!({
                "protocolVersion": unsupported_version
            })),
            RequestId::Number(2),
        );

        let requested_version = client_request
            .params
            .as_ref()
            .and_then(|params| params.get("protocolVersion"))
            .and_then(|version| version.as_str());
        assert_eq!(requested_version, Some(unsupported_version));

        let server_error = Response::error(
            ResponseError {
                code: error_codes::INVALID_REQUEST,
                message: "Unsupported protocol version".to_string(),
                data: Some(json!({
                    "supported": [super::super::PROTOCOL_VERSION],
                    "requested": unsupported_version
                })),
            },
            RequestId::Number(2),
        );
        assert_eq!(server_error.id, client_request.id);

        let error_json = serde_json::to_string(&server_error).unwrap();
        assert!(error_json.contains("Unsupported protocol version"));
        assert!(error_json.contains(super::super::PROTOCOL_VERSION));
        assert!(error_json.contains(unsupported_version));
    }

    #[test]
    fn test_ping_mechanism() {
        let ping_request =
            Request::new(Method::Ping, None, RequestId::String("ping-1".to_string()));

        let request_json = serde_json::to_string(&ping_request).unwrap();
        assert!(request_json.contains(r#""method":"ping""#));
        assert!(request_json.contains(r#""id":"ping-1""#));
        assert!(!request_json.contains("params"));

        let ping_response = Response::success(json!({}), RequestId::String("ping-1".to_string()));

        let response_json = serde_json::to_string(&ping_response).unwrap();
        assert!(response_json.contains(r#""result":{}"#));
        assert!(response_json.contains(r#""id":"ping-1""#));
        assert!(!response_json.contains("error"));

        let mut session_ids = HashSet::new();
        assert!(ping_request.validate_id_uniqueness(&mut session_ids));
        assert!(!ping_request.validate_id_uniqueness(&mut session_ids));

        let mismatched_response =
            Response::success(json!({}), RequestId::String("wrong-id".to_string()));
        assert_ne!(ping_request.id, mismatched_response.id);

        let timeout_error = Response::error(
            ResponseError {
                code: error_codes::REQUEST_CANCELLED,
                message: "Ping timeout".to_string(),
                data: None,
            },
            RequestId::String("ping-1".to_string()),
        );

        let error_json = serde_json::to_string(&timeout_error).unwrap();
        assert!(error_json.contains("Ping timeout"));
        assert!(error_json.contains(&error_codes::REQUEST_CANCELLED.to_string()));
    }

    #[test]
    fn test_ping_pong_sequence() {
        let mut session_ids = HashSet::new();

        let ping_request = Request::new(
            Method::Ping,
            None,
            RequestId::String("ping-seq-1".to_string()),
        );
        assert!(ping_request.validate_id_uniqueness(&mut session_ids));

        let pong_response =
            Response::success(json!({}), RequestId::String("ping-seq-1".to_string()));

        assert_eq!(ping_request.id, pong_response.id);
        assert!(pong_response.result.is_some());
        assert!(pong_response.error.is_none());

        let ping_request_2 = Request::new(
            Method::Ping,
            None,
            RequestId::String("ping-seq-2".to_string()),
        );
        assert!(ping_request_2.validate_id_uniqueness(&mut session_ids));

        assert_ne!(ping_request.id, ping_request_2.id);
    }
}