1use std::fmt::{Display, Formatter};
2
3use completeq_rs::error::CompleteQError;
4use futures::channel::mpsc::SendError;
5use serde::*;
6
7#[derive(Debug, Serialize, Deserialize, Default, PartialEq)]
11pub struct Request<S, P>
12where
13 S: AsRef<str>,
14{
15 #[serde(skip_serializing_if = "Option::is_none")]
19 pub id: Option<usize>,
20 pub jsonrpc: Version,
22 pub method: S,
26 pub params: P,
28}
29
/// Zero-sized marker for the `jsonrpc` protocol version field.
///
/// The type carries no data, so it is `Copy`; it also implements `Eq` and
/// `Hash` so that types embedding it are free to derive those traits too.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Version;
35
36impl Serialize for Version {
37 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
38 where
39 S: Serializer,
40 {
41 serializer.serialize_str("2.0")
42 }
43}
44
45impl<'de> Deserialize<'de> for Version {
46 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47 where
48 D: Deserializer<'de>,
49 {
50 deserializer.deserialize_any(visitor::VersionVisitor)
51 }
52}
53
54#[derive(Debug, Serialize, Deserialize, Default, PartialEq)]
59pub struct Response<S, R, D>
60where
61 S: AsRef<str>,
62{
63 pub id: usize,
66 pub jsonrpc: Version,
68 #[serde(skip_serializing_if = "Option::is_none")]
72 pub result: Option<R>,
73
74 #[serde(skip_serializing_if = "Option::is_none")]
77 pub error: Option<Error<S, D>>,
78}
79
/// Envelope carrying the union of the `Request` and `Response` fields.
///
/// Every field except `jsonrpc` is optional and is left out of the
/// serialized output when `None`.
// NOTE(review): this private type is not referenced anywhere in the visible
// part of the file — confirm it is still needed.
#[derive(Debug, Serialize, Deserialize, Default, PartialEq)]
struct JSONRPC<S, P, R, D> {
    /// Message identifier, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<usize>,
    /// Protocol version marker.
    pub jsonrpc: Version,
    /// Method name, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub method: Option<S>,
    /// Parameter payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<P>,
    /// Result payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<R>,

    /// Error object, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<Error<S, D>>,
}
109
110#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, thiserror::Error)]
113pub struct Error<S, D> {
114 pub code: ErrorCode,
116 pub message: S,
119 pub data: Option<D>,
124}
125
126impl Display for Error<String, serde_json::Value> {
127 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
128 write!(f, "RPCError({}) {}", self.code, self.message)
129 }
130}
131
132impl From<serde_json::Error> for Error<String, serde_json::Value> {
133 fn from(err: serde_json::Error) -> Self {
134 Self {
135 code: ErrorCode::ParseError,
136 message: format!("Serialize/Deserialize json data error: {}", err),
137 data: None,
138 }
139 }
140}
141
142impl From<CompleteQError> for Error<String, serde_json::Value> {
143 fn from(err: CompleteQError) -> Self {
144 Self {
145 code: ErrorCode::InternalError,
146 message: format!("RPC call channel broken: {}", err),
147 data: None,
148 }
149 }
150}
151
152impl From<SendError> for Error<String, serde_json::Value> {
153 fn from(err: SendError) -> Self {
154 Self {
155 code: ErrorCode::InternalError,
156 message: format!("RPC send channel broken: {}", err),
157 data: None,
158 }
159 }
160}
161
162impl Error<String, serde_json::Value> {
163 pub fn from_std_error<E>(e: E) -> Self
164 where
165 E: Display,
166 {
167 Self {
168 code: ErrorCode::InternalError,
169 message: format!("Unknown error: {}", e),
170 data: None,
171 }
172 }
173}
174
175pub fn map_error<E>(err: E) -> Error<String, serde_json::Value>
177where
178 E: Display,
179{
180 Error::<String, serde_json::Value>::from_std_error(err)
181}
182
/// Error codes carried by [`Error`].
///
/// The `Serialize`/`Deserialize` impls in this file map each variant to the
/// integer noted on it.
#[derive(thiserror::Error, Debug, PartialEq, Clone)]
pub enum ErrorCode {
    /// Wire code `-32700`.
    #[error("Invalid JSON was received by the server.")]
    ParseError,
    /// Wire code `-32600`.
    #[error("The JSON sent is not a valid Request object.")]
    InvalidRequest,
    /// Wire code `-32601`.
    #[error("The method does not exist / is not available.")]
    MethodNotFound,
    /// Wire code `-32602`.
    #[error("Invalid method parameter(s).")]
    InvalidParams,
    /// Wire code `-32603`.
    #[error("Internal JSON-RPC error.")]
    InternalError,
    /// Server-defined error: the integer is the wire code, and the string is
    /// used only in the `Display` output (it is not serialized).
    #[error("Server error({0}),{1}")]
    ServerError(i64, String),
}
204
205impl serde::Serialize for ErrorCode {
206 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
207 where
208 S: Serializer,
209 {
210 match self {
211 Self::ParseError => serializer.serialize_i64(-32700),
212 Self::InvalidRequest => serializer.serialize_i64(-32600),
213 Self::MethodNotFound => serializer.serialize_i64(-32601),
214 Self::InvalidParams => serializer.serialize_i64(-32602),
215 Self::InternalError => serializer.serialize_i64(-32603),
216 Self::ServerError(code, _) => serializer.serialize_i64(*code),
217 }
218 }
219}
220
221impl<'de> serde::Deserialize<'de> for ErrorCode {
222 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
223 where
224 D: Deserializer<'de>,
225 {
226 let code = deserializer.deserialize_i64(visitor::ErrorCodeVisitor)?;
227
228 match code {
229 -32700 => Ok(ErrorCode::ParseError),
230 -32600 => Ok(ErrorCode::InvalidRequest),
231 -32601 => Ok(ErrorCode::MethodNotFound),
232 -32602 => Ok(ErrorCode::InvalidParams),
233 -32603 => Ok(ErrorCode::InternalError),
234 _ => {
235 if code <= -32000 && code >= -32099 {
237 Ok(ErrorCode::ServerError(code, "".to_owned()))
238 } else {
239 Err(anyhow::format_err!("Invalid JSONRPC error code {}", code))
240 .map_err(serde::de::Error::custom)
241 }
242 }
243 }
244 }
245}
246
247mod visitor {
248 use serde::de;
249 use std::fmt;
250
251 use crate::Version;
252
253 pub struct ErrorCodeVisitor;
254
255 impl<'de> de::Visitor<'de> for ErrorCodeVisitor {
256 type Value = i64;
257
258 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
259 formatter.write_str("Version string MUST be exactly 2.0")
260 }
261
262 fn visit_i8<E>(self, value: i8) -> Result<Self::Value, E>
263 where
264 E: de::Error,
265 {
266 Ok(i64::from(value))
267 }
268
269 fn visit_i32<E>(self, value: i32) -> Result<Self::Value, E>
270 where
271 E: de::Error,
272 {
273 Ok(i64::from(value))
274 }
275
276 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
277 where
278 E: de::Error,
279 {
280 Ok(value)
281 }
282 }
283
284 pub struct VersionVisitor;
285
286 impl<'de> de::Visitor<'de> for VersionVisitor {
287 type Value = Version;
288
289 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
290 formatter.write_str("an integer between -2^63 and 2^63")
291 }
292
293 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
294 where
295 E: de::Error,
296 {
297 if v != "2.0" {
298 return Err(anyhow::format_err!(
299 "Version string MUST be exactly 2.0, but got `{}`",
300 v
301 ))
302 .map_err(serde::de::Error::custom);
303 }
304
305 Ok(Version {})
306 }
307
308 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
309 where
310 E: de::Error,
311 {
312 if v.as_str() != "2.0" {
313 return Err(anyhow::format_err!(
314 "Version string MUST be exactly 2.0, but got `{}`",
315 v
316 ))
317 .map_err(serde::de::Error::custom);
318 }
319
320 Ok(Version {})
321 }
322 }
323}
324
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::Request;

    /// Tuple-struct params serialize as a JSON array and parse back from one.
    #[test]
    fn test_array_params() {
        _ = pretty_env_logger::try_init();

        #[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
        struct Params<S>(i32, S);

        let outgoing = Request {
            method: "hello",
            params: Params(10, "world"),
            ..Default::default()
        };
        let encoded = serde_json::to_string(&outgoing).unwrap();
        let expected =
            json!({ "jsonrpc":"2.0", "method":"hello","params":[10, "world"]}).to_string();

        assert_eq!(expected, encoded);

        let incoming: Request<String, Params<String>> = serde_json::from_value(
            json!({ "jsonrpc":"2.0", "method":"hello","params":[20, "hello"]}),
        )
        .expect("deserialize json");

        let Params(number, text) = incoming.params;
        assert_eq!(number, 20);
        assert_eq!(text, "hello");
    }

    /// Any `jsonrpc` value other than "2.0" is rejected with a precise message.
    #[test]
    fn test_version_check() {
        #[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
        struct Params<S>(i32, S);

        let outcome: Result<Request<String, Params<String>>, _> = serde_json::from_value(
            json!({"jsonrpc":"3.0", "method":"hello","params":[10, "world"]}),
        );

        let message = outcome.unwrap_err().to_string();
        assert_eq!(
            message,
            "Version string MUST be exactly 2.0, but got `3.0`",
        );
    }

    /// Plain Rust tuples are accepted as positional params.
    #[test]
    fn test_tuple_params() {
        let incoming: Request<String, (i32, String)> = serde_json::from_value(
            json!({ "jsonrpc":"2.0", "method":"hello","params":[10, "world"]}),
        )
        .expect("parse tuple params");

        let (number, text) = incoming.params;
        assert_eq!(number, 10);
        assert_eq!(text, "world");
    }

    /// Named-field params serialize as a JSON object and parse back from one.
    #[test]
    fn test_object_params() {
        _ = pretty_env_logger::try_init();

        #[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
        struct Params<S> {
            id: i32,
            name: S,
        }

        let outgoing = Request {
            method: "hello",
            params: Params {
                id: 10,
                name: "world",
            },
            ..Default::default()
        };
        let encoded = serde_json::to_string(&outgoing).unwrap();
        let expected =
            json!({"jsonrpc":"2.0", "method":"hello","params":{"id":10, "name":"world"}})
                .to_string();

        assert_eq!(expected, encoded);

        let incoming: Request<String, Params<String>> = serde_json::from_value(
            json!({"jsonrpc":"2.0", "method":"hello","params":{"id": 20, "name":"hello"}}),
        )
        .expect("deserialize json");

        let Params { id, name } = incoming.params;
        assert_eq!(id, 20);
        assert_eq!(name, "hello");
    }
}