1use std::fmt::{Display, Formatter};
2
3use futures::channel::mpsc::SendError;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6#[derive(Debug, Serialize, Deserialize, Default, PartialEq)]
10pub struct Request<S, P>
11where
12 S: AsRef<str>,
13{
14 #[serde(skip_serializing_if = "Option::is_none")]
18 pub id: Option<u64>,
19 pub jsonrpc: Version,
21 pub method: S,
25 pub params: P,
27}
28
/// JSON-RPC protocol version marker.
///
/// Carries no data of its own: it is always serialized as the string `"2.0"` and is only
/// accepted back from that exact string.
#[derive(Default, Debug, PartialEq)]
pub struct Version;
34
35impl Serialize for Version {
36 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37 where
38 S: Serializer,
39 {
40 serializer.serialize_str("2.0")
41 }
42}
43
44impl<'de> Deserialize<'de> for Version {
45 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
46 where
47 D: Deserializer<'de>,
48 {
49 deserializer.deserialize_any(visitor::VersionVisitor)
50 }
51}
52
53#[derive(Debug, Serialize, Deserialize, Default, PartialEq)]
58pub struct Response<S, R, D>
59where
60 S: AsRef<str>,
61{
62 pub id: u64,
65 pub jsonrpc: Version,
67 #[serde(skip_serializing_if = "Option::is_none")]
71 pub result: Option<R>,
72
73 #[serde(skip_serializing_if = "Option::is_none")]
76 pub error: Option<Error<S, D>>,
77}
78
79#[derive(Debug, Serialize, Deserialize, Default, PartialEq)]
81struct JSONRPC<S, P, R, D> {
82 #[serde(skip_serializing_if = "Option::is_none")]
86 pub id: Option<usize>,
87 pub jsonrpc: Version,
89 #[serde(skip_serializing_if = "Option::is_none")]
93 pub method: Option<S>,
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub params: Option<P>,
97 #[serde(skip_serializing_if = "Option::is_none")]
101 pub result: Option<R>,
102
103 #[serde(skip_serializing_if = "Option::is_none")]
106 pub error: Option<Error<S, D>>,
107}
108
109#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, thiserror::Error)]
112pub struct Error<S, D> {
113 pub code: ErrorCode,
115 pub message: S,
118 pub data: Option<D>,
123}
124
125impl Display for Error<String, serde_json::Value> {
126 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
127 write!(f, "RPCError({}) {}", self.code, self.message)
128 }
129}
130
131impl From<serde_json::Error> for Error<String, serde_json::Value> {
132 fn from(err: serde_json::Error) -> Self {
133 Self {
134 code: ErrorCode::ParseError,
135 message: format!("Serialize/Deserialize json data error: {}", err),
136 data: None,
137 }
138 }
139}
140
141impl From<std::io::Error> for Error<String, serde_json::Value> {
142 fn from(err: std::io::Error) -> Self {
143 Self {
144 code: ErrorCode::InternalError,
145 message: format!("Underly io error: {}", err),
146 data: None,
147 }
148 }
149}
150
151impl From<SendError> for Error<String, serde_json::Value> {
152 fn from(err: SendError) -> Self {
153 Self {
154 code: ErrorCode::InternalError,
155 message: format!("Underly io error: {}", err),
156 data: None,
157 }
158 }
159}
160
161impl Error<String, serde_json::Value> {
162 pub fn from_std_error<E>(e: E) -> Self
163 where
164 E: Display,
165 {
166 Self {
167 code: ErrorCode::InternalError,
168 message: format!("Unknown error: {}", e),
169 data: None,
170 }
171 }
172}
173
174pub fn map_error<E>(err: E) -> Error<String, serde_json::Value>
176where
177 E: Display,
178{
179 Error::<String, serde_json::Value>::from_std_error(err)
180}
181
182#[derive(thiserror::Error, Debug, PartialEq, Clone)]
187pub enum ErrorCode {
188 #[error("Invalid JSON was received by the server.")]
190 ParseError,
191 #[error("The JSON sent is not a valid Request object.")]
192 InvalidRequest,
193 #[error("The method does not exist / is not available.")]
194 MethodNotFound,
195 #[error("Invalid method parameter(s).")]
196 InvalidParams,
197 #[error("Internal JSON-RPC error.")]
198 InternalError,
199 #[error("Server error({0}),{1}")]
201 ServerError(i64, String),
202}
203
204impl serde::Serialize for ErrorCode {
205 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206 where
207 S: Serializer,
208 {
209 match self {
210 Self::ParseError => serializer.serialize_i64(-32700),
211 Self::InvalidRequest => serializer.serialize_i64(-32600),
212 Self::MethodNotFound => serializer.serialize_i64(-32601),
213 Self::InvalidParams => serializer.serialize_i64(-32602),
214 Self::InternalError => serializer.serialize_i64(-32603),
215 Self::ServerError(code, _) => serializer.serialize_i64(*code),
216 }
217 }
218}
219
220impl<'de> serde::Deserialize<'de> for ErrorCode {
221 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222 where
223 D: Deserializer<'de>,
224 {
225 let code = deserializer.deserialize_i64(visitor::ErrorCodeVisitor)?;
226
227 match code {
228 -32700 => Ok(ErrorCode::ParseError),
229 -32600 => Ok(ErrorCode::InvalidRequest),
230 -32601 => Ok(ErrorCode::MethodNotFound),
231 -32602 => Ok(ErrorCode::InvalidParams),
232 -32603 => Ok(ErrorCode::InternalError),
233 _ => {
234 if code <= -32000 && code >= -32099 {
236 Ok(ErrorCode::ServerError(code, "".to_owned()))
237 } else {
238 Err(format!("Invalid JSONRPC error code {}", code))
239 .map_err(serde::de::Error::custom)
240 }
241 }
242 }
243 }
244}
245
246mod visitor {
247 use serde::de;
248 use std::fmt;
249
250 use super::Version;
251
252 pub struct ErrorCodeVisitor;
253
254 impl<'de> de::Visitor<'de> for ErrorCodeVisitor {
255 type Value = i64;
256
257 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
258 formatter.write_str("Version string MUST be exactly 2.0")
259 }
260
261 fn visit_i8<E>(self, value: i8) -> Result<Self::Value, E>
262 where
263 E: de::Error,
264 {
265 Ok(i64::from(value))
266 }
267
268 fn visit_i32<E>(self, value: i32) -> Result<Self::Value, E>
269 where
270 E: de::Error,
271 {
272 Ok(i64::from(value))
273 }
274
275 fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
276 where
277 E: de::Error,
278 {
279 Ok(value)
280 }
281 }
282
283 pub struct VersionVisitor;
284
285 impl<'de> de::Visitor<'de> for VersionVisitor {
286 type Value = Version;
287
288 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
289 formatter.write_str("an integer between -2^63 and 2^63")
290 }
291
292 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
293 where
294 E: de::Error,
295 {
296 if v != "2.0" {
297 return Err(format!(
298 "Version string MUST be exactly 2.0, but got `{}`",
299 v
300 ))
301 .map_err(serde::de::Error::custom);
302 }
303
304 Ok(Version {})
305 }
306
307 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
308 where
309 E: de::Error,
310 {
311 if v.as_str() != "2.0" {
312 return Err(format!(
313 "Version string MUST be exactly 2.0, but got `{}`",
314 v
315 ))
316 .map_err(serde::de::Error::custom);
317 }
318
319 Ok(Version {})
320 }
321 }
322}
323
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use super::{ErrorCode, Request};

    #[test]
    fn test_array_params() {
        _ = pretty_env_logger::try_init();
        #[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
        struct Params<S>(i32, S);

        let request = Request {
            method: "hello",
            params: Params(10, "world"),
            ..Default::default()
        };

        // Compare parsed JSON values rather than strings: the check still covers every
        // member (including that `id` is omitted) without depending on object key order.
        assert_eq!(
            serde_json::to_value(&request).expect("serialize request"),
            json!({ "jsonrpc": "2.0", "method": "hello", "params": [10, "world"] }),
        );

        let request = serde_json::from_value::<Request<String, Params<String>>>(
            json!({ "jsonrpc": "2.0", "method": "hello", "params": [20, "hello"] }),
        )
        .expect("deserialize json");

        assert_eq!(request.params.0, 20);
        assert_eq!(request.params.1, "hello");
    }

    #[test]
    fn test_version_check() {
        #[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
        struct Params<S>(i32, S);

        let request = serde_json::from_value::<Request<String, Params<String>>>(
            json!({ "jsonrpc": "3.0", "method": "hello", "params": [10, "world"] }),
        );

        assert_eq!(
            format!("{}", request.unwrap_err()),
            "Version string MUST be exactly 2.0, but got `3.0`",
        );
    }

    #[test]
    fn test_tuple_params() {
        let request = serde_json::from_value::<Request<String, (i32, String)>>(
            json!({ "jsonrpc": "2.0", "method": "hello", "params": [10, "world"] }),
        )
        .expect("parse tuple params");

        assert_eq!(request.params.0, 10);
        assert_eq!(request.params.1, "world");
    }

    #[test]
    fn test_object_params() {
        _ = pretty_env_logger::try_init();

        #[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
        struct Params<S> {
            id: i32,
            name: S,
        }

        let request = Request {
            method: "hello",
            params: Params {
                id: 10,
                name: "world",
            },
            ..Default::default()
        };

        assert_eq!(
            serde_json::to_value(&request).expect("serialize request"),
            json!({ "jsonrpc": "2.0", "method": "hello", "params": { "id": 10, "name": "world" } }),
        );

        let request = serde_json::from_value::<Request<String, Params<String>>>(
            json!({ "jsonrpc": "2.0", "method": "hello", "params": { "id": 20, "name": "hello" } }),
        )
        .expect("deserialize json");

        assert_eq!(request.params.id, 20);
        assert_eq!(request.params.name, "hello");
    }

    #[test]
    fn test_error_code_round_trip() {
        let cases = [
            (ErrorCode::ParseError, -32700),
            (ErrorCode::InvalidRequest, -32600),
            (ErrorCode::MethodNotFound, -32601),
            (ErrorCode::InvalidParams, -32602),
            (ErrorCode::InternalError, -32603),
            (ErrorCode::ServerError(-32050, String::new()), -32050),
        ];

        for (code, wire) in cases {
            assert_eq!(serde_json::to_value(&code).expect("serialize code"), json!(wire));
            assert_eq!(
                serde_json::from_value::<ErrorCode>(json!(wire)).expect("deserialize code"),
                code
            );
        }

        // Integers outside the predefined codes and the server-error range are rejected.
        assert!(serde_json::from_value::<ErrorCode>(json!(-1)).is_err());
    }
}