1use crate::RpcId;
2use crate::router::{CallError, CallResult, CallSuccess};
3use crate::rpc_response::{RpcError, RpcResponseParsingError};
4use serde::de::{MapAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use serde_json::Value;
8use std::fmt;
9
10#[derive(Debug, Clone, PartialEq)]
14pub enum RpcResponse {
15 Success(RpcSuccessResponse),
17 Error(RpcErrorResponse),
19}
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub struct RpcSuccessResponse {
24 pub id: RpcId,
26
27 pub result: Value,
29}
30
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
33pub struct RpcErrorResponse {
34 pub id: RpcId,
36
37 pub error: RpcError,
39}
40
41impl RpcResponse {
44 pub fn from_success(id: RpcId, result: Value) -> Self {
45 Self::Success(RpcSuccessResponse { id, result })
46 }
47
48 pub fn from_error(id: RpcId, error: RpcError) -> Self {
49 Self::Error(RpcErrorResponse { id, error })
50 }
51}
52
53impl RpcResponse {
58 pub fn is_success(&self) -> bool {
59 matches!(self, RpcResponse::Success(_))
60 }
61
62 pub fn is_error(&self) -> bool {
63 matches!(self, RpcResponse::Error(_))
64 }
65
66 pub fn id(&self) -> &RpcId {
67 match self {
68 RpcResponse::Success(r) => &r.id,
69 RpcResponse::Error(r) => &r.id,
70 }
71 }
72
73 pub fn into_parts(self) -> (RpcId, core::result::Result<Value, RpcError>) {
76 match self {
77 RpcResponse::Success(r) => (r.id, Ok(r.result)),
78 RpcResponse::Error(r) => (r.id, Err(r.error)),
79 }
80 }
81}
82impl From<CallSuccess> for RpcResponse {
87 fn from(call_success: CallSuccess) -> Self {
89 RpcResponse::from_success(call_success.id, call_success.value)
90 }
91}
92
93impl From<CallError> for RpcResponse {
94 fn from(call_error: CallError) -> Self {
96 let id = call_error.id.clone(); let error = RpcError::from(call_error); RpcResponse::from_error(id, error)
99 }
100}
101
102impl From<CallResult> for RpcResponse {
103 fn from(call_result: CallResult) -> Self {
106 match call_result {
107 Ok(call_success) => RpcResponse::from(call_success),
108 Err(call_error) => RpcResponse::from(call_error),
109 }
110 }
111}
112
113impl Serialize for RpcResponse {
118 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
119 where
120 S: Serializer,
121 {
122 let mut map = serializer.serialize_map(Some(3))?;
123 map.serialize_entry("jsonrpc", "2.0")?;
124
125 match self {
126 RpcResponse::Success(RpcSuccessResponse { id, result }) => {
127 map.serialize_entry("id", id)?;
128 map.serialize_entry("result", result)?;
129 }
130 RpcResponse::Error(RpcErrorResponse { id, error }) => {
131 map.serialize_entry("id", id)?;
132 map.serialize_entry("error", error)?;
133 }
134 }
135
136 map.end()
137 }
138}
139
140impl<'de> Deserialize<'de> for RpcResponse {
141 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
142 where
143 D: Deserializer<'de>,
144 {
145 struct RpcResponseVisitor;
146
147 impl<'de> Visitor<'de> for RpcResponseVisitor {
148 type Value = RpcResponse;
149
150 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
151 formatter.write_str("a JSON-RPC 2.0 response object")
152 }
153
154 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
155 where
156 A: MapAccess<'de>,
157 {
158 let mut version: Option<String> = None;
159 let mut id_val: Option<Value> = None;
160 let mut result_val: Option<Value> = None;
161 let mut error_val: Option<Value> = None;
162
163 while let Some(key) = map.next_key::<String>()? {
164 match key.as_str() {
165 "jsonrpc" => {
166 if version.is_some() {
167 return Err(serde::de::Error::duplicate_field("jsonrpc"));
168 }
169 version = Some(map.next_value()?);
170 }
171 "id" => {
172 if id_val.is_some() {
173 return Err(serde::de::Error::duplicate_field("id"));
174 }
175 id_val = Some(map.next_value()?);
177 }
178 "result" => {
179 if result_val.is_some() {
180 return Err(serde::de::Error::duplicate_field("result"));
181 }
182 result_val = Some(map.next_value()?);
183 }
184 "error" => {
185 if error_val.is_some() {
186 return Err(serde::de::Error::duplicate_field("error"));
187 }
188 error_val = Some(map.next_value()?);
190 }
191 _ => {
192 let _: Value = map.next_value()?;
194 }
195 }
196 }
197
198 let id_for_error = id_val.as_ref().and_then(|v| RpcId::from_value(v.clone()).ok());
200 match version.as_deref() {
201 Some("2.0") => {} Some(v) => {
203 return Err(serde::de::Error::custom(
204 RpcResponseParsingError::InvalidJsonRpcVersion {
205 id: id_for_error,
206 expected: "2.0",
207 actual: Some(Value::String(v.to_string())),
208 },
209 ));
210 }
211 None => {
212 return Err(serde::de::Error::custom(
213 RpcResponseParsingError::MissingJsonRpcVersion { id: id_for_error },
214 ));
215 }
216 };
217
218 let id = match id_val {
220 Some(v) => RpcId::from_value(v)
221 .map_err(|e| serde::de::Error::custom(RpcResponseParsingError::InvalidId(e)))?,
222 None => return Err(serde::de::Error::custom(RpcResponseParsingError::MissingId)),
223 };
224
225 match (result_val, error_val) {
227 (Some(result), None) => Ok(RpcResponse::Success(RpcSuccessResponse { id, result })),
228 (None, Some(error_value)) => {
229 let error: RpcError = serde_json::from_value(error_value)
231 .map_err(|e| serde::de::Error::custom(RpcResponseParsingError::InvalidErrorObject(e)))?;
232 Ok(RpcResponse::Error(RpcErrorResponse { id, error }))
233 }
234 (Some(_), Some(_)) => Err(serde::de::Error::custom(RpcResponseParsingError::BothResultAndError {
235 id: id.clone(),
236 })),
237 (None, None) => Err(serde::de::Error::custom(
238 RpcResponseParsingError::MissingResultAndError { id: id.clone() },
239 )),
240 }
241 }
242 }
243
244 deserializer.deserialize_map(RpcResponseVisitor)
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error as RouterError;
    use serde_json::{from_value, json, to_value};

    type TestResult<T> = core::result::Result<T, Box<dyn std::error::Error>>;

    /// Test helper: assembles a `CallError` from an id, a method name and a router error.
    fn create_call_error(id: impl Into<RpcId>, method: &str, error: RouterError) -> CallError {
        let id = id.into();
        let method = method.to_string();
        CallError { id, method, error }
    }

    /// A success response serializes to the expected JSON and round-trips unchanged.
    #[test]
    fn test_rpc_response_success_ser_de() -> TestResult<()> {
        // -- Fixtures
        let fx_id = RpcId::Number(1);
        let fx_result = json!({"data": "ok"});
        let fx_response = RpcResponse::from_success(fx_id.clone(), fx_result.clone());
        let fx_json = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"data": "ok"}
        });

        // -- Serialize
        let as_json = to_value(&fx_response)?;
        assert_eq!(as_json, fx_json);

        // -- Deserialize
        let round_tripped: RpcResponse = from_value(as_json)?;
        assert_eq!(round_tripped, fx_response);
        assert_eq!(round_tripped.id(), &fx_id);
        assert!(round_tripped.is_success());
        assert!(!round_tripped.is_error());

        let (got_id, got_outcome) = round_tripped.into_parts();
        assert_eq!(got_id, fx_id);
        assert_eq!(got_outcome.unwrap(), fx_result);

        Ok(())
    }

    /// An error response (with `data`) serializes to the expected JSON and round-trips unchanged.
    #[test]
    fn test_rpc_response_error_ser_de() -> TestResult<()> {
        // -- Fixtures
        let fx_id = RpcId::String("req-abc".into());
        let fx_error = RpcError {
            code: -32601,
            message: "Method not found".to_string(),
            data: Some(json!("method_name")),
        };
        let fx_response = RpcResponse::from_error(fx_id.clone(), fx_error.clone());
        let fx_json = json!({
            "jsonrpc": "2.0",
            "id": "req-abc",
            "error": {
                "code": -32601,
                "message": "Method not found",
                "data": "method_name"
            }
        });

        // -- Serialize
        let as_json = to_value(&fx_response)?;
        assert_eq!(as_json, fx_json);

        // -- Deserialize
        let round_tripped: RpcResponse = from_value(as_json)?;
        assert_eq!(round_tripped, fx_response);
        assert_eq!(round_tripped.id(), &fx_id);
        assert!(!round_tripped.is_success());
        assert!(round_tripped.is_error());

        let (got_id, got_outcome) = round_tripped.into_parts();
        assert_eq!(got_id, fx_id);
        assert_eq!(got_outcome.unwrap_err(), fx_error);

        Ok(())
    }

    /// An error response without `data` and with a null id round-trips unchanged.
    #[test]
    fn test_rpc_response_error_ser_de_no_data() -> TestResult<()> {
        // -- Fixtures
        let fx_id = RpcId::Null;
        let fx_error = RpcError {
            code: -32700,
            message: "Parse error".to_string(),
            data: None,
        };
        let fx_response = RpcResponse::from_error(fx_id.clone(), fx_error.clone());
        let fx_json = json!({
            "jsonrpc": "2.0",
            "id": null,
            "error": {
                "code": -32700,
                "message": "Parse error"
            }
        });

        // -- Serialize
        let as_json = to_value(&fx_response)?;
        assert_eq!(as_json, fx_json);

        // -- Deserialize
        let round_tripped: RpcResponse = from_value(as_json)?;
        assert_eq!(round_tripped, fx_response);
        assert_eq!(round_tripped.id(), &fx_id);
        assert!(round_tripped.is_error());

        let (got_id, got_outcome) = round_tripped.into_parts();
        assert_eq!(got_id, fx_id);
        assert_eq!(got_outcome.unwrap_err(), fx_error);

        Ok(())
    }

    /// Each malformed payload below must be rejected by deserialization.
    #[test]
    fn test_rpc_response_de_invalid() {
        let invalid_jsons = vec![
            // no "jsonrpc" member
            json!({"id": 1, "result": "ok"}),
            // "jsonrpc" is not "2.0"
            json!({"jsonrpc": "1.0", "id": 1, "result": "ok"}),
            // no "id" member
            json!({"jsonrpc": "2.0", "result": "ok"}),
            // neither "result" nor "error"
            json!({"jsonrpc": "2.0", "id": 1}),
            // both "result" and "error"
            json!({"jsonrpc": "2.0", "id": 1, "result": "ok", "error": {"code": 1, "message": "err"}}),
            // "error" is a string instead of an object
            json!({"jsonrpc": "2.0", "id": 1, "error": "not an object"}),
            // "error" object without "code"
            json!({"jsonrpc": "2.0", "id": 1, "error": {"message": "err"}}),
            // "error" object without "message"
            json!({"jsonrpc": "2.0", "id": 1, "error": {"code": 1}}),
            // "id" is an array
            json!({"jsonrpc": "2.0", "id": [1,2], "result": "ok"}),
        ];

        for candidate in invalid_jsons {
            let outcome = from_value::<RpcResponse>(candidate.clone());
            assert!(outcome.is_err(), "Expected error for invalid JSON: {}", candidate);
        }
    }

    /// `From<CallSuccess>` keeps the id and uses `value` as the result.
    #[test]
    fn test_from_call_success() -> TestResult<()> {
        let fx_call_success = CallSuccess {
            id: RpcId::Number(101),
            method: "test_method".to_string(),
            value: json!({"success": true}),
        };

        let converted = RpcResponse::from(fx_call_success);

        match converted {
            RpcResponse::Error(_) => panic!("Expected RpcResponse::Success"),
            RpcResponse::Success(success) => {
                assert_eq!(success.id, RpcId::Number(101));
                assert_eq!(success.result, json!({"success": true}));
            }
        }
        Ok(())
    }

    /// `From<CallError>` keeps the id and converts the router error into an `RpcError`.
    #[test]
    fn test_from_call_error() -> TestResult<()> {
        let fx_call_error = create_call_error(102, "test_method", RouterError::MethodUnknown);

        let converted = RpcResponse::from(fx_call_error);

        match converted {
            RpcResponse::Success(_) => panic!("Expected RpcResponse::Error"),
            RpcResponse::Error(failure) => {
                assert_eq!(failure.id, RpcId::Number(102));
                assert_eq!(failure.error.code, RpcError::CODE_METHOD_NOT_FOUND);
                assert_eq!(failure.error.message, "Method not found");
                assert!(failure.error.data.is_some());
            }
        }
        Ok(())
    }

    /// `From<CallResult>` with an `Ok` value yields a success response.
    #[test]
    fn test_from_call_result_ok() -> TestResult<()> {
        let fx_call_result: CallResult = Ok(CallSuccess {
            id: 103.into(),
            method: "test_method".to_string(),
            value: json!("ok_data"),
        });

        let converted = RpcResponse::from(fx_call_result);

        match converted {
            RpcResponse::Error(_) => panic!("Expected RpcResponse::Success"),
            RpcResponse::Success(success) => {
                assert_eq!(success.id, RpcId::Number(103));
                assert_eq!(success.result, json!("ok_data"));
            }
        }
        Ok(())
    }

    /// `From<CallResult>` with an `Err` value yields an error response.
    #[test]
    fn test_from_call_result_err() -> TestResult<()> {
        let fx_call_error = create_call_error(
            "err-104",
            "test_method",
            RouterError::ParamsMissingButRequested,
        );
        let fx_call_result: CallResult = Err(fx_call_error);

        let converted = RpcResponse::from(fx_call_result);

        match converted {
            RpcResponse::Success(_) => panic!("Expected RpcResponse::Error"),
            RpcResponse::Error(failure) => {
                assert_eq!(failure.id, RpcId::String("err-104".into()));
                assert_eq!(failure.error.code, RpcError::CODE_INVALID_PARAMS);
                assert_eq!(failure.error.message, "Invalid params");
                assert!(failure.error.data.is_some());
            }
        }
        Ok(())
    }
}
508
509