1#![warn(
2 clippy::all,
3 clippy::dbg_macro,
4 clippy::todo,
5 clippy::empty_enum,
6 clippy::enum_glob_use,
7 clippy::mem_forget,
8 clippy::unused_self,
9 clippy::filter_map_next,
10 clippy::needless_continue,
11 clippy::needless_borrow,
12 clippy::match_wildcard_for_single_variants,
13 clippy::if_let_mutex,
14 unexpected_cfgs,
15 clippy::await_holding_lock,
16 clippy::match_on_vec_items,
17 clippy::imprecise_flops,
18 clippy::suboptimal_flops,
19 clippy::lossy_float_literal,
20 clippy::rest_pat_in_fully_bound_structs,
21 clippy::fn_params_excessive_bools,
22 clippy::exit,
23 clippy::inefficient_to_string,
24 clippy::linkedlist,
25 clippy::macro_use_imports,
26 clippy::option_option,
27 clippy::verbose_file_reads,
28 clippy::unnested_or_patterns,
29 clippy::str_to_string,
30 rust_2018_idioms,
31 future_incompatible,
32 nonstandard_style,
33 missing_debug_implementations
34)]
35#![deny(unreachable_pub)]
36#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
37
38use std::borrow::Cow;
39
40use axum::body::Bytes;
41use axum::extract::{FromRequest, Request};
42use axum::http::{header, HeaderMap};
43use axum::response::{IntoResponse, Response};
44use axum::Json;
45use cfg_if::cfg_if;
46use serde::de::DeserializeOwned;
47use serde::{Deserialize, Serialize};
48
49cfg_if! {
50 if #[cfg(feature = "serde_json")] {
51 pub use serde_json::Value;
52 pub mod error;
53 use crate::error::{JsonRpcError, JsonRpcErrorReason};
54 }
55 else if #[cfg(feature = "simd")] {
56 pub use simd_json::OwnedValue as Value;
57 pub mod error;
58 use crate::error::{JsonRpcError, JsonRpcErrorReason};
59 }
60 else {
61 compile_error!("features `serde_json` and `simd` are mutually exclusive");
62 }
63}
64
65pub type JrpcResult = Result<JsonRpcResponse, JsonRpcResponse>;
67
68#[derive(Debug)]
69pub struct JsonRpcRequest {
70 pub id: Id,
71 pub method: String,
72 pub params: Value,
73}
74
75impl Serialize for JsonRpcRequest {
76 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: serde::Serializer,
79 {
80 #[derive(Serialize)]
81 struct Helper<'a> {
82 jsonrpc: &'static str,
83 id: Id,
84 method: &'a str,
85 params: &'a Value,
86 }
87
88 Helper {
89 jsonrpc: JSONRPC,
90 id: self.id.clone(),
91 method: &self.method,
92 params: &self.params,
93 }
94 .serialize(serializer)
95 }
96}
97
98impl<'de> Deserialize<'de> for JsonRpcRequest {
99 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
100 where
101 D: serde::Deserializer<'de>,
102 {
103 use serde::de::Error;
104
105 #[derive(Deserialize)]
106 struct Helper<'a> {
107 #[serde(borrow)]
108 jsonrpc: Cow<'a, str>,
109 id: Id,
110 method: String,
111 params: Value,
112 }
113
114 let helper = Helper::deserialize(deserializer)?;
115 if helper.jsonrpc == JSONRPC {
116 Ok(Self {
117 id: helper.id,
118 method: helper.method,
119 params: helper.params,
120 })
121 } else {
122 Err(D::Error::custom("Unknown jsonrpc version"))
123 }
124 }
125}
126
127#[derive(Clone, Debug)]
128pub struct JsonRpcExtractor {
146 pub parsed: Value,
147 pub method: String,
148 pub id: Id,
149}
150
151impl JsonRpcExtractor {
152 pub fn get_answer_id(&self) -> Id {
153 self.id.clone()
154 }
155
156 pub fn parse_params<T: DeserializeOwned>(self) -> Result<T, JsonRpcResponse> {
157 cfg_if::cfg_if! {
158 if #[cfg(feature = "simd")] {
159 match simd_json::serde::from_owned_value(self.parsed){
160 Ok(v) => Ok(v),
161 Err(e) => {
162 let error = JsonRpcError::new(
163 JsonRpcErrorReason::InvalidParams,
164 e.to_string(),
165 Value::default(),
166 );
167 Err(JsonRpcResponse::error(self.id, error))
168 }
169
170 }
171 } else if #[cfg(feature = "serde_json")] {
172 match serde_json::from_value(self.parsed){
173 Ok(v) => Ok(v),
174 Err(e) => {
175 let error = JsonRpcError::new(
176 JsonRpcErrorReason::InvalidParams,
177 e.to_string(),
178 Value::Null,
179 );
180 Err(JsonRpcResponse::error(self.id, error))
181 }
182 }
183 }
184 }
185 }
186
187 pub fn method(&self) -> &str {
188 &self.method
189 }
190
191 pub fn method_not_found(&self, method: &str) -> JsonRpcResponse {
192 let error = JsonRpcError::new(
193 JsonRpcErrorReason::MethodNotFound,
194 format!("Method `{}` not found", method),
195 Value::default(),
196 );
197
198 JsonRpcResponse::error(self.id.clone(), error)
199 }
200}
201
202impl<S> FromRequest<S> for JsonRpcExtractor
203where
204 Bytes: FromRequest<S>,
205 S: Send + Sync,
206{
207 type Rejection = JsonRpcResponse;
208
209 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
210 if !json_content_type(req.headers()) {
211 return Err(JsonRpcResponse {
212 id: Id::None(()),
213 result: JsonRpcAnswer::Error(JsonRpcError::new(
214 JsonRpcErrorReason::InvalidRequest,
215 "Invalid content type".to_owned(),
216 Value::default(),
217 )),
218 });
219 }
220
221 #[allow(unused_mut)]
222 let mut bytes = match Bytes::from_request(req, state).await {
223 Ok(a) => a.to_vec(),
224 Err(_) => {
225 return Err(JsonRpcResponse {
226 id: Id::None(()),
227 result: JsonRpcAnswer::Error(JsonRpcError::new(
228 JsonRpcErrorReason::InvalidRequest,
229 "Invalid request".to_owned(),
230 Value::default(),
231 )),
232 })
233 }
234 };
235
236 cfg_if!(
237 if #[cfg(feature = "simd")] {
238 let parsed: JsonRpcRequest = match simd_json::from_slice(&mut bytes){
239 Ok(a) => a,
240 Err(e) => {
241 return Err(JsonRpcResponse {
242 id: Id::None(()),
243 result: JsonRpcAnswer::Error(JsonRpcError::new(
244 JsonRpcErrorReason::InvalidRequest,
245 e.to_string(),
246 Value::default(),
247 )),
248 })
249 }
250 };
251 } else if #[cfg(feature = "serde_json")] {
252 let parsed: JsonRpcRequest = match serde_json::from_slice(&bytes){
253 Ok(a) => a,
254 Err(e) => {
255 return Err(JsonRpcResponse {
256 id: Id::None(()),
257 result: JsonRpcAnswer::Error(JsonRpcError::new(
258 JsonRpcErrorReason::InvalidRequest,
259 e.to_string(),
260 Value::default(),
261 )),
262 })
263 }
264 };
265 }
266 );
267
268 Ok(Self {
269 parsed: parsed.params,
270 method: parsed.method,
271 id: parsed.id,
272 })
273 }
274}
275
276fn json_content_type(headers: &HeaderMap) -> bool {
277 let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
278 content_type
279 } else {
280 return false;
281 };
282
283 let content_type = if let Ok(content_type) = content_type.to_str() {
284 content_type
285 } else {
286 return false;
287 };
288
289 let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
290 mime
291 } else {
292 return false;
293 };
294
295 let is_json_content_type = mime.type_() == "application"
296 && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json"));
297
298 is_json_content_type
299}
300
301#[derive(Debug, Clone, PartialEq)]
302pub struct JsonRpcResponse {
304 pub result: JsonRpcAnswer,
306 pub id: Id,
308}
309
310impl JsonRpcResponse {
311 fn new<ID>(id: ID, result: JsonRpcAnswer) -> Self
312 where
313 Id: From<ID>,
314 {
315 Self {
316 result,
317 id: id.into(),
318 }
319 }
320
321 pub fn success<T, ID>(id: ID, result: T) -> Self
324 where
325 T: Serialize,
326 Id: From<ID>,
327 {
328 cfg_if::cfg_if! {
329 if #[cfg(feature = "simd")] {
330 match simd_json::serde::to_owned_value(result) {
331 Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
332 Err(e) => {
333 let err = JsonRpcError::new(
334 JsonRpcErrorReason::InternalError,
335 e.to_string(),
336 Value::default(),
337 );
338 JsonRpcResponse::error(id, err)
339 }
340 }
341 } else if #[cfg(feature = "serde_json")] {
342 match serde_json::to_value(result) {
343 Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
344 Err(e) => {
345 let err = JsonRpcError::new(
346 JsonRpcErrorReason::InternalError,
347 e.to_string(),
348 Value::Null,
349 );
350 JsonRpcResponse::error(id, err)
351 }
352 }
353 }
354 }
355 }
356
357 pub fn error<ID>(id: ID, error: JsonRpcError) -> Self
358 where
359 Id: From<ID>,
360 {
361 let id = id.into();
362 JsonRpcResponse {
363 result: JsonRpcAnswer::Error(error),
364 id,
365 }
366 }
367}
368
369impl Serialize for JsonRpcResponse {
370 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
371 where
372 S: serde::Serializer,
373 {
374 #[derive(Serialize)]
375 struct Helper<'a> {
376 jsonrpc: &'static str,
377 #[serde(flatten)]
378 result: &'a JsonRpcAnswer,
379 id: Id,
380 }
381
382 Helper {
383 jsonrpc: JSONRPC,
384 result: &self.result,
385 id: self.id.clone(),
386 }
387 .serialize(serializer)
388 }
389}
390
391impl<'de> Deserialize<'de> for JsonRpcResponse {
392 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
393 where
394 D: serde::Deserializer<'de>,
395 {
396 use serde::de::Error;
397
398 #[derive(Deserialize)]
399 struct Helper<'a> {
400 #[serde(borrow)]
401 jsonrpc: Cow<'a, str>,
402 #[serde(flatten)]
403 result: JsonRpcAnswer,
404 id: Id,
405 }
406
407 let helper = Helper::deserialize(deserializer)?;
408 if helper.jsonrpc == JSONRPC {
409 Ok(Self {
410 result: helper.result,
411 id: helper.id,
412 })
413 } else {
414 Err(D::Error::custom("Unknown jsonrpc version"))
415 }
416 }
417}
418
419impl IntoResponse for JsonRpcResponse {
420 fn into_response(self) -> Response {
421 Json(self).into_response()
422 }
423}
424
425#[derive(Serialize, Clone, Debug, Deserialize, PartialEq)]
426#[serde(rename_all = "lowercase")]
427pub enum JsonRpcAnswer {
429 Result(Value),
430 Error(JsonRpcError),
431}
432
/// Value of the mandatory `jsonrpc` member: the only protocol version accepted
/// when deserializing, and the one written when serializing.
const JSONRPC: &str = "2.0";
434
435#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Hash)]
439#[serde(untagged)]
440pub enum Id {
441 Num(i64),
442 Str(String),
443 None(()),
444}
445
446impl From<()> for Id {
447 fn from(val: ()) -> Self {
448 Id::None(val)
449 }
450}
451
452impl From<i64> for Id {
453 fn from(val: i64) -> Self {
454 Id::Num(val)
455 }
456}
457
458impl From<String> for Id {
459 fn from(val: String) -> Self {
460 Id::Str(val)
461 }
462}
463
// End-to-end tests: build an axum router around a JSON-RPC handler and drive it
// through `axum_test::TestServer`. Compiled only with the `serde_json` backend
// plus the `anyhow_error` feature (presumably what provides the
// `anyhow::Error` -> `JsonRpcError` conversion used in `handler` — TODO confirm).
#[cfg(test)]
#[cfg(all(feature = "anyhow_error", feature = "serde_json"))]
mod test {
    use crate::{
        Deserialize, JrpcResult, JsonRpcAnswer, JsonRpcError, JsonRpcErrorReason, JsonRpcExtractor,
        JsonRpcRequest, JsonRpcResponse,
    };
    use axum::routing::post;
    use serde::Serialize;
    use serde_json::Value;

    /// A known method returns its result, and an unknown method returns a
    /// `MethodNotFound` error body. Both are sent with HTTP 200.
    #[tokio::test]
    async fn test() {
        use axum::http::StatusCode;
        use axum::Router;
        use axum_test::TestServer;

        let app = Router::new().route("/", post(handler));

        let client = TestServer::new(app).unwrap();

        // `add` with a = 0, b = 111 must answer 111.
        let res = client
            .post("/")
            .json(&JsonRpcRequest {
                id: 0.into(),
                method: "add".to_owned(),
                params: serde_json::to_value(Test { a: 0, b: 111 }).unwrap(),
            })
            .await;
        assert_eq!(res.status_code(), StatusCode::OK);
        let response = res.json::<JsonRpcResponse>();
        assert_eq!(response.result, JsonRpcAnswer::Result(111.into()));

        // An unknown method is reported in the JSON-RPC body, not via the HTTP status.
        let res = client
            .post("/")
            .json(&JsonRpcRequest {
                id: 0.into(),
                method: "lol".to_owned(),
                params: serde_json::to_value(()).unwrap(),
            })
            .await;

        assert_eq!(res.status_code(), StatusCode::OK);

        let response = res.json::<JsonRpcResponse>();

        let error = JsonRpcError::new(
            JsonRpcErrorReason::MethodNotFound,
            format!("Method `{}` not found", "lol"),
            Value::Null,
        );

        let error = JsonRpcResponse::error(0, error);

        // Compare the serialized JSON of the expected and received responses.
        assert_eq!(
            serde_json::to_value(error).unwrap(),
            serde_json::to_value(response).unwrap()
        );
    }

    /// Test handler that dispatches on the method name:
    /// - `add` takes `Test` object params;
    /// - `sub` and `div` take `[i32; 2]` params and call fallible helpers;
    /// - anything else gets a `MethodNotFound` response.
    async fn handler(value: JsonRpcExtractor) -> JrpcResult {
        let answer_id = value.get_answer_id();
        println!("{:?}", value);
        match value.method.as_str() {
            "add" => {
                let request: Test = value.parse_params()?;
                let result = request.a + request.b;
                Ok(JsonRpcResponse::success(answer_id, result))
            }
            "sub" => {
                let result: [i32; 2] = value.parse_params()?;
                // `e` is an `anyhow::Error`; `e.into()` needs a conversion into
                // `JsonRpcError` that is not defined in this module.
                let result = match failing_sub(result[0], result[1]).await {
                    Ok(result) => result,
                    Err(e) => return Err(JsonRpcResponse::error(answer_id, e.into())),
                };
                Ok(JsonRpcResponse::success(answer_id, result))
            }
            "div" => {
                let result: [i32; 2] = value.parse_params()?;
                // `e.into()` uses the `From<CustomError> for JsonRpcError` impl below.
                let result = match failing_div(result[0], result[1]).await {
                    Ok(result) => result,
                    Err(e) => return Err(JsonRpcResponse::error(answer_id, e.into())),
                };

                Ok(JsonRpcResponse::success(answer_id, result))
            }
            method => Ok(value.method_not_found(method)),
        }
    }

    /// Returns `a - b`; fails unless `a > b`.
    async fn failing_sub(a: i32, b: i32) -> anyhow::Result<i32> {
        anyhow::ensure!(a > b, "a must be greater than b");
        Ok(a - b)
    }

    /// Returns `a / b`; fails with `CustomError::DivideByZero` when `b == 0`.
    async fn failing_div(a: i32, b: i32) -> Result<i32, CustomError> {
        if b == 0 {
            Err(CustomError::DivideByZero)
        } else {
            Ok(a / b)
        }
    }

    /// Params for the `add` method.
    #[derive(Deserialize, Serialize, Debug)]
    struct Test {
        a: i32,
        b: i32,
    }

    /// Application-level error used to exercise a custom error conversion.
    #[derive(Debug, thiserror::Error)]
    enum CustomError {
        #[error("Divisor must not be equal to 0")]
        DivideByZero,
    }

    /// Maps `CustomError` to server error code -32099. This code lies in the
    /// JSON-RPC range reserved for implementation-defined server errors.
    impl From<CustomError> for JsonRpcError {
        fn from(error: CustomError) -> Self {
            JsonRpcError::new(
                JsonRpcErrorReason::ServerError(-32099),
                error.to_string(),
                serde_json::Value::Null,
            )
        }
    }
}