1use serde::{Deserialize, Serialize};
28use serde_json::Value;
29
30#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
35#[serde(untagged)]
36pub enum RequestId {
37 String(String),
38 Integer(i64),
39}
40
41impl std::fmt::Display for RequestId {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 RequestId::String(s) => write!(f, "{}", s),
45 RequestId::Integer(i) => write!(f, "{}", i),
46 }
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct JsonRpcRequest {
53 pub id: RequestId,
54 pub method: String,
55 #[serde(default, skip_serializing_if = "Option::is_none")]
56 pub params: Option<Value>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct JsonRpcNotification {
62 pub method: String,
63 #[serde(default, skip_serializing_if = "Option::is_none")]
64 pub params: Option<Value>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct JsonRpcResponse {
70 pub id: RequestId,
71 pub result: Value,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct JsonRpcErrorData {
77 pub code: i64,
78 pub message: String,
79 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub data: Option<Value>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct JsonRpcError {
86 pub error: JsonRpcErrorData,
87 pub id: RequestId,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
101#[serde(untagged)]
102pub enum JsonRpcMessage {
103 Request(JsonRpcRequest),
104 Response(JsonRpcResponse),
105 Error(JsonRpcError),
106 Notification(JsonRpcNotification),
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_id_string() {
        let id: RequestId = serde_json::from_str(r#""req_1""#).unwrap();
        assert_eq!(id, RequestId::String("req_1".to_string()));
        assert_eq!(id.to_string(), "req_1");
    }

    #[test]
    fn test_request_id_integer() {
        let id: RequestId = serde_json::from_str("42").unwrap();
        assert_eq!(id, RequestId::Integer(42));
        assert_eq!(id.to_string(), "42");
    }

    #[test]
    fn test_request_roundtrip() {
        let req = JsonRpcRequest {
            id: RequestId::Integer(1),
            method: "thread/start".to_string(),
            params: Some(serde_json::json!({"instructions": "hello"})),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, RequestId::Integer(1));
        assert_eq!(parsed.method, "thread/start");
        // The payload must survive the roundtrip, not just the envelope.
        assert_eq!(parsed.params, req.params);
    }

    #[test]
    fn test_request_no_params() {
        let json = r#"{"id":1,"method":"turn/interrupt"}"#;
        let req: JsonRpcRequest = serde_json::from_str(json).unwrap();
        assert!(req.params.is_none());

        // `None` params must be skipped on output, not written as `null`.
        let out = serde_json::to_string(&req).unwrap();
        assert!(!out.contains("params"));
    }

    #[test]
    fn test_request_null_params() {
        // An explicit `null` is treated the same as a missing `params`.
        let json = r#"{"id":"req_7","method":"turn/interrupt","params":null}"#;
        let req: JsonRpcRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.id, RequestId::String("req_7".to_string()));
        assert!(req.params.is_none());
    }

    #[test]
    fn test_notification_roundtrip() {
        let notif = JsonRpcNotification {
            method: "turn/started".to_string(),
            params: Some(serde_json::json!({"threadId": "th_1", "turnId": "t_1"})),
        };
        let json = serde_json::to_string(&notif).unwrap();
        let parsed: JsonRpcNotification = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.method, "turn/started");
        assert_eq!(parsed.params, notif.params);
    }

    #[test]
    fn test_response_roundtrip() {
        let resp = JsonRpcResponse {
            id: RequestId::Integer(1),
            result: serde_json::json!({"threadId": "th_abc"}),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: JsonRpcResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, RequestId::Integer(1));
        assert_eq!(parsed.result, resp.result);
    }

    #[test]
    fn test_error_roundtrip() {
        let err = JsonRpcError {
            id: RequestId::Integer(1),
            error: JsonRpcErrorData {
                code: -32600,
                message: "Invalid request".to_string(),
                data: None,
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        // `data: None` must be omitted from the output.
        assert!(!json.contains("data"));

        let parsed: JsonRpcError = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, RequestId::Integer(1));
        assert_eq!(parsed.error.code, -32600);
        assert_eq!(parsed.error.message, "Invalid request");
        assert!(parsed.error.data.is_none());
    }

    #[test]
    fn test_message_dispatch_request() {
        let json = r#"{"id":1,"method":"thread/start","params":{}}"#;
        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
        match msg {
            JsonRpcMessage::Request(req) => {
                assert_eq!(req.id, RequestId::Integer(1));
                assert_eq!(req.method, "thread/start");
            }
            other => panic!("expected Request, got {:?}", other),
        }
    }

    #[test]
    fn test_message_dispatch_response() {
        let json = r#"{"id":1,"result":{"threadId":"th_1"}}"#;
        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Response(_)));
    }

    #[test]
    fn test_message_dispatch_error() {
        let json = r#"{"id":1,"error":{"code":-32600,"message":"bad"}}"#;
        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Error(_)));

        // String ids dispatch the same way.
        let json = r#"{"id":"req_1","error":{"code":-32601,"message":"nope"}}"#;
        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
        match msg {
            JsonRpcMessage::Error(e) => {
                assert_eq!(e.id, RequestId::String("req_1".to_string()));
                assert_eq!(e.error.code, -32601);
            }
            other => panic!("expected Error, got {:?}", other),
        }
    }

    #[test]
    fn test_message_dispatch_notification() {
        let json = r#"{"method":"turn/started","params":{"threadId":"th_1"}}"#;
        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Notification(_)));
    }

    #[test]
    fn test_message_dispatch_notification_without_params() {
        let json = r#"{"method":"turn/completed"}"#;
        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
        match msg {
            JsonRpcMessage::Notification(n) => {
                assert_eq!(n.method, "turn/completed");
                assert!(n.params.is_none());
            }
            other => panic!("expected Notification, got {:?}", other),
        }
    }

    #[test]
    fn test_no_jsonrpc_field() {
        // Requests are serialized without any `jsonrpc` version member.
        let req = JsonRpcRequest {
            id: RequestId::Integer(1),
            method: "test".to_string(),
            params: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("jsonrpc"));
    }
}