axum_jrpc/
lib.rs

1#![warn(
2    clippy::all,
3    clippy::dbg_macro,
4    clippy::todo,
5    clippy::empty_enum,
6    clippy::enum_glob_use,
7    clippy::mem_forget,
8    clippy::unused_self,
9    clippy::filter_map_next,
10    clippy::needless_continue,
11    clippy::needless_borrow,
12    clippy::match_wildcard_for_single_variants,
13    clippy::if_let_mutex,
14    unexpected_cfgs,
15    clippy::await_holding_lock,
16    clippy::imprecise_flops,
17    clippy::suboptimal_flops,
18    clippy::lossy_float_literal,
19    clippy::rest_pat_in_fully_bound_structs,
20    clippy::fn_params_excessive_bools,
21    clippy::exit,
22    clippy::inefficient_to_string,
23    clippy::linkedlist,
24    clippy::macro_use_imports,
25    clippy::option_option,
26    clippy::verbose_file_reads,
27    clippy::unnested_or_patterns,
28    clippy::str_to_string,
29    rust_2018_idioms,
30    future_incompatible,
31    nonstandard_style,
32    missing_debug_implementations
33)]
34#![deny(unreachable_pub)]
35#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
36
37use std::borrow::Cow;
38
39use axum::body::Bytes;
40use axum::extract::{FromRequest, Request};
41use axum::http::{header, HeaderMap};
42use axum::response::{IntoResponse, Response};
43use axum::Json;
44use cfg_if::cfg_if;
45use serde::de::DeserializeOwned;
46use serde::{Deserialize, Serialize};
47
48cfg_if! {
49    if #[cfg(feature = "serde_json")] {
50        pub use serde_json::Value;
51        pub mod error;
52        use crate::error::{JsonRpcError, JsonRpcErrorReason};
53    }
54    else if #[cfg(feature = "simd")] {
55        pub use simd_json::OwnedValue as Value;
56        pub mod error;
57        use crate::error::{JsonRpcError, JsonRpcErrorReason};
58    }
59    else {
60        compile_error!("features `serde_json` and `simd` are mutually exclusive");
61    }
62}
63
64/// Hack until [try_trait_v2](https://github.com/rust-lang/rust/issues/84277) is not stabilized
65pub type JrpcResult = Result<JsonRpcResponse, JsonRpcResponse>;
66
67#[derive(Debug)]
68pub struct JsonRpcRequest {
69    pub id: Id,
70    pub method: String,
71    pub params: Value,
72}
73
74impl Serialize for JsonRpcRequest {
75    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
76    where
77        S: serde::Serializer,
78    {
79        #[derive(Serialize)]
80        struct Helper<'a> {
81            jsonrpc: &'static str,
82            id: Id,
83            method: &'a str,
84            params: &'a Value,
85        }
86
87        Helper {
88            jsonrpc: JSONRPC,
89            id: self.id.clone(),
90            method: &self.method,
91            params: &self.params,
92        }
93        .serialize(serializer)
94    }
95}
96
97impl<'de> Deserialize<'de> for JsonRpcRequest {
98    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
99    where
100        D: serde::Deserializer<'de>,
101    {
102        use serde::de::Error;
103
104        #[derive(Deserialize)]
105        struct Helper<'a> {
106            #[serde(borrow)]
107            jsonrpc: Cow<'a, str>,
108            id: Id,
109            method: String,
110            params: Option<Value>,
111        }
112
113        let helper = Helper::deserialize(deserializer)?;
114        if helper.jsonrpc == JSONRPC {
115            Ok(Self {
116                id: helper.id,
117                method: helper.method,
118                params: helper.params.unwrap_or(Value::default()),
119            })
120        } else {
121            Err(D::Error::custom("Unknown jsonrpc version"))
122        }
123    }
124}
125
126#[derive(Clone, Debug)]
127/// Parses a JSON-RPC request, and returns the request ID, the method name, and the parameters.
128/// If the request is invalid, returns an error.
129/// ```rust
130/// use axum_jrpc::{JrpcResult, JsonRpcExtractor, JsonRpcResponse};
131///
132/// fn router(req: JsonRpcExtractor) -> JrpcResult {
133///   let req_id = req.get_answer_id();
134///   let method = req.method();
135///   match method {
136///     "add" => {
137///        let params: [i32;2] = req.parse_params()?;
138///        return Ok(JsonRpcResponse::success(req_id, params[0] + params[1]));
139///     }
140///     m =>  Ok(req.method_not_found(m))
141///   }
142/// }
143/// ```
144pub struct JsonRpcExtractor {
145    pub parsed: Value,
146    pub method: String,
147    pub id: Id,
148}
149
150impl JsonRpcExtractor {
151    pub fn get_answer_id(&self) -> Id {
152        self.id.clone()
153    }
154
155    pub fn parse_params<T: DeserializeOwned>(self) -> Result<T, JsonRpcResponse> {
156        cfg_if::cfg_if! {
157            if #[cfg(feature = "serde_json")] {
158                match serde_json::from_value(self.parsed){
159                    Ok(v) => Ok(v),
160                    Err(e) => {
161                        let error = JsonRpcError::new(
162                            JsonRpcErrorReason::InvalidParams,
163                            e.to_string(),
164                            Value::Null,
165                        );
166                        Err(JsonRpcResponse::error(self.id, error))
167                    }
168                }
169            } else if #[cfg(feature = "simd")] {
170                match simd_json::serde::from_owned_value(self.parsed){
171                    Ok(v) => Ok(v),
172                    Err(e) => {
173                        let error = JsonRpcError::new(
174                            JsonRpcErrorReason::InvalidParams,
175                            e.to_string(),
176                            Value::default(),
177                        );
178                        Err(JsonRpcResponse::error(self.id, error))
179                    }
180                }
181            }
182        }
183    }
184
185    pub fn method(&self) -> &str {
186        &self.method
187    }
188
189    pub fn method_not_found(&self, method: &str) -> JsonRpcResponse {
190        let error = JsonRpcError::new(
191            JsonRpcErrorReason::MethodNotFound,
192            format!("Method `{}` not found", method),
193            Value::default(),
194        );
195
196        JsonRpcResponse::error(self.id.clone(), error)
197    }
198}
199
200impl<S> FromRequest<S> for JsonRpcExtractor
201where
202    Bytes: FromRequest<S>,
203    S: Send + Sync,
204{
205    type Rejection = JsonRpcResponse;
206
207    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
208        if !json_content_type(req.headers()) {
209            return Err(JsonRpcResponse {
210                id: Id::None(()),
211                result: JsonRpcAnswer::Error(JsonRpcError::new(
212                    JsonRpcErrorReason::InvalidRequest,
213                    "Invalid content type".to_owned(),
214                    Value::default(),
215                )),
216            });
217        }
218
219        #[allow(unused_mut)]
220        let mut bytes = match Bytes::from_request(req, state).await {
221            Ok(a) => a.to_vec(),
222            Err(_) => {
223                return Err(JsonRpcResponse {
224                    id: Id::None(()),
225                    result: JsonRpcAnswer::Error(JsonRpcError::new(
226                        JsonRpcErrorReason::InvalidRequest,
227                        "Invalid request".to_owned(),
228                        Value::default(),
229                    )),
230                })
231            }
232        };
233
234        cfg_if!(
235            if #[cfg(feature = "serde_json")] {
236               let parsed: JsonRpcRequest = match serde_json::from_slice(&bytes){
237                    Ok(a) => a,
238                    Err(e) => {
239                        return Err(JsonRpcResponse {
240                            id: Id::None(()),
241                            result: JsonRpcAnswer::Error(JsonRpcError::new(
242                                JsonRpcErrorReason::InvalidRequest,
243                                e.to_string(),
244                                Value::default(),
245                            )),
246                        })
247                    }
248                };
249            } else if #[cfg(feature = "simd")] {
250               let parsed: JsonRpcRequest = match simd_json::from_slice(&mut bytes){
251                    Ok(a) => a,
252                    Err(e) => {
253                        return Err(JsonRpcResponse {
254                            id: Id::None(()),
255                            result: JsonRpcAnswer::Error(JsonRpcError::new(
256                                JsonRpcErrorReason::InvalidRequest,
257                                e.to_string(),
258                                Value::default(),
259                            )),
260                        })
261                    }
262                };
263            }
264        );
265
266        Ok(Self {
267            parsed: parsed.params,
268            method: parsed.method,
269            id: parsed.id,
270        })
271    }
272}
273
274fn json_content_type(headers: &HeaderMap) -> bool {
275    let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
276        content_type
277    } else {
278        return false;
279    };
280
281    let content_type = if let Ok(content_type) = content_type.to_str() {
282        content_type
283    } else {
284        return false;
285    };
286
287    let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
288        mime
289    } else {
290        return false;
291    };
292
293    let is_json_content_type = mime.type_() == "application"
294        && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json"));
295
296    is_json_content_type
297}
298
299#[derive(Debug, Clone, PartialEq)]
300/// A JSON-RPC response.
301pub struct JsonRpcResponse {
302    /// Request content.
303    pub result: JsonRpcAnswer,
304    /// The request ID.
305    pub id: Id,
306}
307
308impl JsonRpcResponse {
309    fn new<ID>(id: ID, result: JsonRpcAnswer) -> Self
310    where
311        Id: From<ID>,
312    {
313        Self {
314            result,
315            id: id.into(),
316        }
317    }
318
319    /// Returns a response with the given result
320    /// Returns JsonRpcError if the `result` is invalid input for [`serde_json::to_value`]
321    pub fn success<T, ID>(id: ID, result: T) -> Self
322    where
323        T: Serialize,
324        Id: From<ID>,
325    {
326        cfg_if::cfg_if! {
327          if #[cfg(feature = "serde_json")] {
328            match serde_json::to_value(result) {
329                Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
330                Err(e) => {
331                    let err = JsonRpcError::new(
332                        JsonRpcErrorReason::InternalError,
333                        e.to_string(),
334                        Value::Null,
335                    );
336                    JsonRpcResponse::error(id, err)
337                }
338            }
339          } else if #[cfg(feature = "simd")] {
340            match simd_json::serde::to_owned_value(result) {
341                Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
342                Err(e) => {
343                    let err = JsonRpcError::new(
344                        JsonRpcErrorReason::InternalError,
345                        e.to_string(),
346                        Value::default(),
347                    );
348                    JsonRpcResponse::error(id, err)
349                }
350            }
351          }
352        }
353    }
354
355    pub fn error<ID>(id: ID, error: JsonRpcError) -> Self
356    where
357        Id: From<ID>,
358    {
359        let id = id.into();
360        JsonRpcResponse {
361            result: JsonRpcAnswer::Error(error),
362            id,
363        }
364    }
365}
366
367impl Serialize for JsonRpcResponse {
368    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
369    where
370        S: serde::Serializer,
371    {
372        #[derive(Serialize)]
373        struct Helper<'a> {
374            jsonrpc: &'static str,
375            #[serde(flatten)]
376            result: &'a JsonRpcAnswer,
377            id: Id,
378        }
379
380        Helper {
381            jsonrpc: JSONRPC,
382            result: &self.result,
383            id: self.id.clone(),
384        }
385        .serialize(serializer)
386    }
387}
388
389impl<'de> Deserialize<'de> for JsonRpcResponse {
390    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
391    where
392        D: serde::Deserializer<'de>,
393    {
394        use serde::de::Error;
395
396        #[derive(Deserialize)]
397        struct Helper<'a> {
398            #[serde(borrow)]
399            jsonrpc: Cow<'a, str>,
400            #[serde(flatten)]
401            result: JsonRpcAnswer,
402            id: Id,
403        }
404
405        let helper = Helper::deserialize(deserializer)?;
406        if helper.jsonrpc == JSONRPC {
407            Ok(Self {
408                result: helper.result,
409                id: helper.id,
410            })
411        } else {
412            Err(D::Error::custom("Unknown jsonrpc version"))
413        }
414    }
415}
416
417impl IntoResponse for JsonRpcResponse {
418    fn into_response(self) -> Response {
419        Json(self).into_response()
420    }
421}
422
423#[derive(Serialize, Clone, Debug, Deserialize, PartialEq)]
424#[serde(rename_all = "lowercase")]
425/// JsonRpc [response object](https://www.jsonrpc.org/specification#response_object)
426pub enum JsonRpcAnswer {
427    Result(Value),
428    Error(JsonRpcError),
429}
430
/// Protocol version tag written to the `jsonrpc` field when serializing and
/// required in that field when deserializing.
const JSONRPC: &str = "2.0";
432
433/// An identifier established by the Client that MUST contain a String, Number,
434/// or NULL value if included. If it is not included it is assumed to be a notification.
435/// The value SHOULD normally not be Null and Numbers SHOULD NOT contain fractional parts
436#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Hash)]
437#[serde(untagged)]
438pub enum Id {
439    Num(i64),
440    Str(String),
441    None(()),
442}
443
444impl From<()> for Id {
445    fn from(val: ()) -> Self {
446        Id::None(val)
447    }
448}
449
450impl From<i64> for Id {
451    fn from(val: i64) -> Self {
452        Id::Num(val)
453    }
454}
455
456impl From<String> for Id {
457    fn from(val: String) -> Self {
458        Id::Str(val)
459    }
460}
461
#[cfg(test)]
#[cfg(all(feature = "anyhow_error", feature = "serde_json"))]
mod test {
    use crate::{
        Deserialize, JrpcResult, JsonRpcAnswer, JsonRpcError, JsonRpcErrorReason, JsonRpcExtractor,
        JsonRpcRequest, JsonRpcResponse,
    };
    use axum::routing::post;
    use serde::Serialize;
    use serde_json::Value;

    /// Sends one request for a known method and one for an unknown method
    /// through a test server and checks both responses.
    #[tokio::test]
    async fn test() {
        use axum::http::StatusCode;
        use axum::Router;
        use axum_test::TestServer;

        // Any router can be substituted here; this one exposes only `handler`.
        let app = Router::new().route("/", post(handler));
        let client = TestServer::new(app).unwrap();

        // Known method: `add` sums the two fields of `Test`.
        let add_request = JsonRpcRequest {
            id: 0.into(),
            method: "add".to_owned(),
            params: serde_json::to_value(Test { a: 0, b: 111 }).unwrap(),
        };
        let res = client.post("/").json(&add_request).await;
        assert_eq!(res.status_code(), StatusCode::OK);
        let response = res.json::<JsonRpcResponse>();
        assert_eq!(response.result, JsonRpcAnswer::Result(111.into()));

        // Unknown method: the handler answers with a `MethodNotFound` error.
        let unknown_request = JsonRpcRequest {
            id: 0.into(),
            method: "lol".to_owned(),
            params: serde_json::to_value(()).unwrap(),
        };
        let res = client.post("/").json(&unknown_request).await;
        assert_eq!(res.status_code(), StatusCode::OK);
        let response = res.json::<JsonRpcResponse>();

        let expected = JsonRpcResponse::error(
            0,
            JsonRpcError::new(
                JsonRpcErrorReason::MethodNotFound,
                format!("Method `{}` not found", "lol"),
                Value::Null,
            ),
        );

        // Compared as JSON values rather than as typed responses.
        assert_eq!(
            serde_json::to_value(expected).unwrap(),
            serde_json::to_value(response).unwrap()
        );
    }

    /// Dispatches on the method name: `add`, `sub` and `div` are implemented,
    /// anything else yields a `MethodNotFound` response.
    async fn handler(value: JsonRpcExtractor) -> JrpcResult {
        let answer_id = value.get_answer_id();
        println!("{:?}", value);
        match value.method.as_str() {
            "add" => {
                let Test { a, b } = value.parse_params()?;
                Ok(JsonRpcResponse::success(answer_id, a + b))
            }
            "sub" => {
                let [a, b]: [i32; 2] = value.parse_params()?;
                match failing_sub(a, b).await {
                    Ok(difference) => Ok(JsonRpcResponse::success(answer_id, difference)),
                    Err(e) => Err(JsonRpcResponse::error(answer_id, e.into())),
                }
            }
            "div" => {
                let [a, b]: [i32; 2] = value.parse_params()?;
                match failing_div(a, b).await {
                    Ok(quotient) => Ok(JsonRpcResponse::success(answer_id, quotient)),
                    Err(e) => Err(JsonRpcResponse::error(answer_id, e.into())),
                }
            }
            method => Ok(value.method_not_found(method)),
        }
    }

    /// Subtracts `b` from `a`, failing with an `anyhow` error unless `a > b`.
    async fn failing_sub(a: i32, b: i32) -> anyhow::Result<i32> {
        anyhow::ensure!(a > b, "a must be greater than b");
        Ok(a - b)
    }

    /// Divides `a` by `b`, failing with [`CustomError::DivideByZero`] when `b` is zero.
    async fn failing_div(a: i32, b: i32) -> Result<i32, CustomError> {
        match b {
            0 => Err(CustomError::DivideByZero),
            divisor => Ok(a / divisor),
        }
    }

    /// Parameters of the `add` method.
    #[derive(Deserialize, Serialize, Debug)]
    struct Test {
        a: i32,
        b: i32,
    }

    /// Application-specific error used to exercise custom error conversion.
    #[derive(Debug, thiserror::Error)]
    enum CustomError {
        #[error("Divisor must not be equal to 0")]
        DivideByZero,
    }

    impl From<CustomError> for JsonRpcError {
        /// Maps the custom error to a server error with code `-32099`, using
        /// the error's display text as the message.
        fn from(error: CustomError) -> Self {
            let message = error.to_string();
            JsonRpcError::new(
                JsonRpcErrorReason::ServerError(-32099),
                message,
                serde_json::Value::Null,
            )
        }
    }
}