// Crate-wide lint configuration: a curated set of clippy and rustc lints is
// raised to warnings, `unreachable_pub` is a hard error, and two lints are
// explicitly allowed.
#![warn(
    clippy::all,
    clippy::dbg_macro,
    clippy::todo,
    clippy::empty_enum,
    clippy::enum_glob_use,
    clippy::mem_forget,
    clippy::unused_self,
    clippy::filter_map_next,
    clippy::needless_continue,
    clippy::needless_borrow,
    clippy::match_wildcard_for_single_variants,
    clippy::if_let_mutex,
    unexpected_cfgs,
    clippy::await_holding_lock,
    clippy::imprecise_flops,
    clippy::suboptimal_flops,
    clippy::lossy_float_literal,
    clippy::rest_pat_in_fully_bound_structs,
    clippy::fn_params_excessive_bools,
    clippy::exit,
    clippy::inefficient_to_string,
    clippy::linkedlist,
    clippy::macro_use_imports,
    clippy::option_option,
    clippy::verbose_file_reads,
    clippy::unnested_or_patterns,
    clippy::str_to_string,
    rust_2018_idioms,
    future_incompatible,
    nonstandard_style,
    missing_debug_implementations
)]
#![deny(unreachable_pub)]
// `elided_lifetimes_in_paths` is a member of the `rust_2018_idioms` group
// warned above; listing it here opts back out of that one lint.
#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
36
37use std::borrow::Cow;
38
39use axum::body::Bytes;
40use axum::extract::{FromRequest, Request};
41use axum::http::{header, HeaderMap};
42use axum::response::{IntoResponse, Response};
43use axum::Json;
44use cfg_if::cfg_if;
45use serde::de::DeserializeOwned;
46use serde::{Deserialize, Serialize};
47
48cfg_if! {
49 if #[cfg(feature = "serde_json")] {
50 pub use serde_json::Value;
51 pub mod error;
52 use crate::error::{JsonRpcError, JsonRpcErrorReason};
53 }
54 else if #[cfg(feature = "simd")] {
55 pub use simd_json::OwnedValue as Value;
56 pub mod error;
57 use crate::error::{JsonRpcError, JsonRpcErrorReason};
58 }
59 else {
60 compile_error!("features `serde_json` and `simd` are mutually exclusive");
61 }
62}
63
64pub type JrpcResult = Result<JsonRpcResponse, JsonRpcResponse>;
66
67#[derive(Debug)]
68pub struct JsonRpcRequest {
69 pub id: Id,
70 pub method: String,
71 pub params: Value,
72}
73
74impl Serialize for JsonRpcRequest {
75 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
76 where
77 S: serde::Serializer,
78 {
79 #[derive(Serialize)]
80 struct Helper<'a> {
81 jsonrpc: &'static str,
82 id: Id,
83 method: &'a str,
84 params: &'a Value,
85 }
86
87 Helper {
88 jsonrpc: JSONRPC,
89 id: self.id.clone(),
90 method: &self.method,
91 params: &self.params,
92 }
93 .serialize(serializer)
94 }
95}
96
97impl<'de> Deserialize<'de> for JsonRpcRequest {
98 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
99 where
100 D: serde::Deserializer<'de>,
101 {
102 use serde::de::Error;
103
104 #[derive(Deserialize)]
105 struct Helper<'a> {
106 #[serde(borrow)]
107 jsonrpc: Cow<'a, str>,
108 id: Id,
109 method: String,
110 params: Option<Value>,
111 }
112
113 let helper = Helper::deserialize(deserializer)?;
114 if helper.jsonrpc == JSONRPC {
115 Ok(Self {
116 id: helper.id,
117 method: helper.method,
118 params: helper.params.unwrap_or(Value::default()),
119 })
120 } else {
121 Err(D::Error::custom("Unknown jsonrpc version"))
122 }
123 }
124}
125
126#[derive(Clone, Debug)]
127pub struct JsonRpcExtractor {
145 pub parsed: Value,
146 pub method: String,
147 pub id: Id,
148}
149
150impl JsonRpcExtractor {
151 pub fn get_answer_id(&self) -> Id {
152 self.id.clone()
153 }
154
155 pub fn parse_params<T: DeserializeOwned>(self) -> Result<T, JsonRpcResponse> {
156 cfg_if::cfg_if! {
157 if #[cfg(feature = "serde_json")] {
158 match serde_json::from_value(self.parsed){
159 Ok(v) => Ok(v),
160 Err(e) => {
161 let error = JsonRpcError::new(
162 JsonRpcErrorReason::InvalidParams,
163 e.to_string(),
164 Value::Null,
165 );
166 Err(JsonRpcResponse::error(self.id, error))
167 }
168 }
169 } else if #[cfg(feature = "simd")] {
170 match simd_json::serde::from_owned_value(self.parsed){
171 Ok(v) => Ok(v),
172 Err(e) => {
173 let error = JsonRpcError::new(
174 JsonRpcErrorReason::InvalidParams,
175 e.to_string(),
176 Value::default(),
177 );
178 Err(JsonRpcResponse::error(self.id, error))
179 }
180 }
181 }
182 }
183 }
184
185 pub fn method(&self) -> &str {
186 &self.method
187 }
188
189 pub fn method_not_found(&self, method: &str) -> JsonRpcResponse {
190 let error = JsonRpcError::new(
191 JsonRpcErrorReason::MethodNotFound,
192 format!("Method `{}` not found", method),
193 Value::default(),
194 );
195
196 JsonRpcResponse::error(self.id.clone(), error)
197 }
198}
199
200impl<S> FromRequest<S> for JsonRpcExtractor
201where
202 Bytes: FromRequest<S>,
203 S: Send + Sync,
204{
205 type Rejection = JsonRpcResponse;
206
207 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
208 if !json_content_type(req.headers()) {
209 return Err(JsonRpcResponse {
210 id: Id::None(()),
211 result: JsonRpcAnswer::Error(JsonRpcError::new(
212 JsonRpcErrorReason::InvalidRequest,
213 "Invalid content type".to_owned(),
214 Value::default(),
215 )),
216 });
217 }
218
219 #[allow(unused_mut)]
220 let mut bytes = match Bytes::from_request(req, state).await {
221 Ok(a) => a.to_vec(),
222 Err(_) => {
223 return Err(JsonRpcResponse {
224 id: Id::None(()),
225 result: JsonRpcAnswer::Error(JsonRpcError::new(
226 JsonRpcErrorReason::InvalidRequest,
227 "Invalid request".to_owned(),
228 Value::default(),
229 )),
230 })
231 }
232 };
233
234 cfg_if!(
235 if #[cfg(feature = "serde_json")] {
236 let parsed: JsonRpcRequest = match serde_json::from_slice(&bytes){
237 Ok(a) => a,
238 Err(e) => {
239 return Err(JsonRpcResponse {
240 id: Id::None(()),
241 result: JsonRpcAnswer::Error(JsonRpcError::new(
242 JsonRpcErrorReason::InvalidRequest,
243 e.to_string(),
244 Value::default(),
245 )),
246 })
247 }
248 };
249 } else if #[cfg(feature = "simd")] {
250 let parsed: JsonRpcRequest = match simd_json::from_slice(&mut bytes){
251 Ok(a) => a,
252 Err(e) => {
253 return Err(JsonRpcResponse {
254 id: Id::None(()),
255 result: JsonRpcAnswer::Error(JsonRpcError::new(
256 JsonRpcErrorReason::InvalidRequest,
257 e.to_string(),
258 Value::default(),
259 )),
260 })
261 }
262 };
263 }
264 );
265
266 Ok(Self {
267 parsed: parsed.params,
268 method: parsed.method,
269 id: parsed.id,
270 })
271 }
272}
273
274fn json_content_type(headers: &HeaderMap) -> bool {
275 let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
276 content_type
277 } else {
278 return false;
279 };
280
281 let content_type = if let Ok(content_type) = content_type.to_str() {
282 content_type
283 } else {
284 return false;
285 };
286
287 let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
288 mime
289 } else {
290 return false;
291 };
292
293 let is_json_content_type = mime.type_() == "application"
294 && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json"));
295
296 is_json_content_type
297}
298
299#[derive(Debug, Clone, PartialEq)]
300pub struct JsonRpcResponse {
302 pub result: JsonRpcAnswer,
304 pub id: Id,
306}
307
308impl JsonRpcResponse {
309 fn new<ID>(id: ID, result: JsonRpcAnswer) -> Self
310 where
311 Id: From<ID>,
312 {
313 Self {
314 result,
315 id: id.into(),
316 }
317 }
318
319 pub fn success<T, ID>(id: ID, result: T) -> Self
322 where
323 T: Serialize,
324 Id: From<ID>,
325 {
326 cfg_if::cfg_if! {
327 if #[cfg(feature = "serde_json")] {
328 match serde_json::to_value(result) {
329 Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
330 Err(e) => {
331 let err = JsonRpcError::new(
332 JsonRpcErrorReason::InternalError,
333 e.to_string(),
334 Value::Null,
335 );
336 JsonRpcResponse::error(id, err)
337 }
338 }
339 } else if #[cfg(feature = "simd")] {
340 match simd_json::serde::to_owned_value(result) {
341 Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
342 Err(e) => {
343 let err = JsonRpcError::new(
344 JsonRpcErrorReason::InternalError,
345 e.to_string(),
346 Value::default(),
347 );
348 JsonRpcResponse::error(id, err)
349 }
350 }
351 }
352 }
353 }
354
355 pub fn error<ID>(id: ID, error: JsonRpcError) -> Self
356 where
357 Id: From<ID>,
358 {
359 let id = id.into();
360 JsonRpcResponse {
361 result: JsonRpcAnswer::Error(error),
362 id,
363 }
364 }
365}
366
367impl Serialize for JsonRpcResponse {
368 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
369 where
370 S: serde::Serializer,
371 {
372 #[derive(Serialize)]
373 struct Helper<'a> {
374 jsonrpc: &'static str,
375 #[serde(flatten)]
376 result: &'a JsonRpcAnswer,
377 id: Id,
378 }
379
380 Helper {
381 jsonrpc: JSONRPC,
382 result: &self.result,
383 id: self.id.clone(),
384 }
385 .serialize(serializer)
386 }
387}
388
389impl<'de> Deserialize<'de> for JsonRpcResponse {
390 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
391 where
392 D: serde::Deserializer<'de>,
393 {
394 use serde::de::Error;
395
396 #[derive(Deserialize)]
397 struct Helper<'a> {
398 #[serde(borrow)]
399 jsonrpc: Cow<'a, str>,
400 #[serde(flatten)]
401 result: JsonRpcAnswer,
402 id: Id,
403 }
404
405 let helper = Helper::deserialize(deserializer)?;
406 if helper.jsonrpc == JSONRPC {
407 Ok(Self {
408 result: helper.result,
409 id: helper.id,
410 })
411 } else {
412 Err(D::Error::custom("Unknown jsonrpc version"))
413 }
414 }
415}
416
417impl IntoResponse for JsonRpcResponse {
418 fn into_response(self) -> Response {
419 Json(self).into_response()
420 }
421}
422
423#[derive(Serialize, Clone, Debug, Deserialize, PartialEq)]
424#[serde(rename_all = "lowercase")]
425pub enum JsonRpcAnswer {
427 Result(Value),
428 Error(JsonRpcError),
429}
430
/// Protocol version string written to, and required in, the `jsonrpc` member
/// of every request and response.
const JSONRPC: &str = "2.0";
432
433#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Hash)]
437#[serde(untagged)]
438pub enum Id {
439 Num(i64),
440 Str(String),
441 None(()),
442}
443
444impl From<()> for Id {
445 fn from(val: ()) -> Self {
446 Id::None(val)
447 }
448}
449
450impl From<i64> for Id {
451 fn from(val: i64) -> Self {
452 Id::Num(val)
453 }
454}
455
456impl From<String> for Id {
457 fn from(val: String) -> Self {
458 Id::Str(val)
459 }
460}
461
// End-to-end tests that drive an axum router through `axum_test::TestServer`.
// They are compiled only with both the `anyhow_error` and `serde_json`
// features enabled.
#[cfg(test)]
#[cfg(all(feature = "anyhow_error", feature = "serde_json"))]
mod test {
    use crate::{
        Deserialize, JrpcResult, JsonRpcAnswer, JsonRpcError, JsonRpcErrorReason, JsonRpcExtractor,
        JsonRpcRequest, JsonRpcResponse,
    };
    use axum::routing::post;
    use serde::Serialize;
    use serde_json::Value;

    /// Posts an `add` call and checks the sum comes back as the result.
    /// Then posts an unknown method and checks the reply serializes the same
    /// as a locally built `MethodNotFound` error response.
    #[tokio::test]
    async fn test() {
        use axum::http::StatusCode;
        use axum::Router;
        use axum_test::TestServer;

        let app = Router::new().route("/", post(handler));

        let client = TestServer::new(app).unwrap();

        // Happy path: 0 + 111.
        let res = client
            .post("/")
            .json(&JsonRpcRequest {
                id: 0.into(),
                method: "add".to_owned(),
                params: serde_json::to_value(Test { a: 0, b: 111 }).unwrap(),
            })
            .await;
        assert_eq!(res.status_code(), StatusCode::OK);
        let response = res.json::<JsonRpcResponse>();
        assert_eq!(response.result, JsonRpcAnswer::Result(111.into()));

        // Unknown method: the error is still delivered with HTTP 200.
        let res = client
            .post("/")
            .json(&JsonRpcRequest {
                id: 0.into(),
                method: "lol".to_owned(),
                params: serde_json::to_value(()).unwrap(),
            })
            .await;

        assert_eq!(res.status_code(), StatusCode::OK);

        let response = res.json::<JsonRpcResponse>();

        let error = JsonRpcError::new(
            JsonRpcErrorReason::MethodNotFound,
            format!("Method `{}` not found", "lol"),
            Value::Null,
        );

        let error = JsonRpcResponse::error(0, error);

        // Compared as JSON values, i.e. on the wire representation.
        assert_eq!(
            serde_json::to_value(error).unwrap(),
            serde_json::to_value(response).unwrap()
        );
    }

    /// Test handler that dispatches on the method name: `add`, `sub` and `div`
    /// are implemented; any other name yields a `MethodNotFound` response.
    async fn handler(value: JsonRpcExtractor) -> JrpcResult {
        let answer_id = value.get_answer_id();
        println!("{:?}", value);
        match value.method.as_str() {
            "add" => {
                let request: Test = value.parse_params()?;
                let result = request.a + request.b;
                Ok(JsonRpcResponse::success(answer_id, result))
            }
            "sub" => {
                let result: [i32; 2] = value.parse_params()?;
                // NOTE(review): `e.into()` needs `From<anyhow::Error> for
                // JsonRpcError`, presumably provided by the `error` module
                // under the `anyhow_error` feature — confirm there.
                let result = match failing_sub(result[0], result[1]).await {
                    Ok(result) => result,
                    Err(e) => return Err(JsonRpcResponse::error(answer_id, e.into())),
                };
                Ok(JsonRpcResponse::success(answer_id, result))
            }
            "div" => {
                let result: [i32; 2] = value.parse_params()?;
                // Uses the `From<CustomError> for JsonRpcError` impl below.
                let result = match failing_div(result[0], result[1]).await {
                    Ok(result) => result,
                    Err(e) => return Err(JsonRpcResponse::error(answer_id, e.into())),
                };

                Ok(JsonRpcResponse::success(answer_id, result))
            }
            method => Ok(value.method_not_found(method)),
        }
    }

    /// Returns `a - b`, or an `anyhow` error unless `a > b`.
    async fn failing_sub(a: i32, b: i32) -> anyhow::Result<i32> {
        anyhow::ensure!(a > b, "a must be greater than b");
        Ok(a - b)
    }

    /// Returns `a / b`, or `CustomError::DivideByZero` when `b == 0`.
    async fn failing_div(a: i32, b: i32) -> Result<i32, CustomError> {
        if b == 0 {
            Err(CustomError::DivideByZero)
        } else {
            Ok(a / b)
        }
    }

    /// Params shape for the `add` method.
    #[derive(Deserialize, Serialize, Debug)]
    struct Test {
        a: i32,
        b: i32,
    }

    /// Application-specific error used to exercise custom error conversion.
    #[derive(Debug, thiserror::Error)]
    enum CustomError {
        #[error("Divisor must not be equal to 0")]
        DivideByZero,
    }

    /// Maps `CustomError` to a JSON-RPC server error with code -32099 and the
    /// error's display text as the message.
    impl From<CustomError> for JsonRpcError {
        fn from(error: CustomError) -> Self {
            JsonRpcError::new(
                JsonRpcErrorReason::ServerError(-32099),
                error.to_string(),
                serde_json::Value::Null,
            )
        }
    }
}