axum_jrpc/
lib.rs

1#![warn(
2    clippy::all,
3    clippy::dbg_macro,
4    clippy::todo,
5    clippy::empty_enum,
6    clippy::enum_glob_use,
7    clippy::mem_forget,
8    clippy::unused_self,
9    clippy::filter_map_next,
10    clippy::needless_continue,
11    clippy::needless_borrow,
12    clippy::match_wildcard_for_single_variants,
13    clippy::if_let_mutex,
14    unexpected_cfgs,
15    clippy::await_holding_lock,
16    clippy::match_on_vec_items,
17    clippy::imprecise_flops,
18    clippy::suboptimal_flops,
19    clippy::lossy_float_literal,
20    clippy::rest_pat_in_fully_bound_structs,
21    clippy::fn_params_excessive_bools,
22    clippy::exit,
23    clippy::inefficient_to_string,
24    clippy::linkedlist,
25    clippy::macro_use_imports,
26    clippy::option_option,
27    clippy::verbose_file_reads,
28    clippy::unnested_or_patterns,
29    clippy::str_to_string,
30    rust_2018_idioms,
31    future_incompatible,
32    nonstandard_style,
33    missing_debug_implementations
34)]
35#![deny(unreachable_pub)]
36#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
37
38use std::borrow::Cow;
39
40use axum::body::Bytes;
41use axum::extract::{FromRequest, Request};
42use axum::http::{header, HeaderMap};
43use axum::response::{IntoResponse, Response};
44use axum::Json;
45use cfg_if::cfg_if;
46use serde::de::DeserializeOwned;
47use serde::{Deserialize, Serialize};
48
49cfg_if! {
50    if #[cfg(feature = "serde_json")] {
51        pub use serde_json::Value;
52        pub mod error;
53        use crate::error::{JsonRpcError, JsonRpcErrorReason};
54    }
55    else if #[cfg(feature = "simd")] {
56        pub use simd_json::OwnedValue as Value;
57        pub mod error;
58        use crate::error::{JsonRpcError, JsonRpcErrorReason};
59    }
60    else {
61        compile_error!("features `serde_json` and `simd` are mutually exclusive");
62    }
63}
64
65/// Hack until [try_trait_v2](https://github.com/rust-lang/rust/issues/84277) is not stabilized
66pub type JrpcResult = Result<JsonRpcResponse, JsonRpcResponse>;
67
68#[derive(Debug)]
69pub struct JsonRpcRequest {
70    pub id: Id,
71    pub method: String,
72    pub params: Value,
73}
74
75impl Serialize for JsonRpcRequest {
76    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77    where
78        S: serde::Serializer,
79    {
80        #[derive(Serialize)]
81        struct Helper<'a> {
82            jsonrpc: &'static str,
83            id: Id,
84            method: &'a str,
85            params: &'a Value,
86        }
87
88        Helper {
89            jsonrpc: JSONRPC,
90            id: self.id.clone(),
91            method: &self.method,
92            params: &self.params,
93        }
94        .serialize(serializer)
95    }
96}
97
98impl<'de> Deserialize<'de> for JsonRpcRequest {
99    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
100    where
101        D: serde::Deserializer<'de>,
102    {
103        use serde::de::Error;
104
105        #[derive(Deserialize)]
106        struct Helper<'a> {
107            #[serde(borrow)]
108            jsonrpc: Cow<'a, str>,
109            id: Id,
110            method: String,
111            params: Value,
112        }
113
114        let helper = Helper::deserialize(deserializer)?;
115        if helper.jsonrpc == JSONRPC {
116            Ok(Self {
117                id: helper.id,
118                method: helper.method,
119                params: helper.params,
120            })
121        } else {
122            Err(D::Error::custom("Unknown jsonrpc version"))
123        }
124    }
125}
126
127#[derive(Clone, Debug)]
128/// Parses a JSON-RPC request, and returns the request ID, the method name, and the parameters.
129/// If the request is invalid, returns an error.
130/// ```rust
131/// use axum_jrpc::{JrpcResult, JsonRpcExtractor, JsonRpcResponse};
132///
133/// fn router(req: JsonRpcExtractor) -> JrpcResult {
134///   let req_id = req.get_answer_id();
135///   let method = req.method();
136///   match method {
137///     "add" => {
138///        let params: [i32;2] = req.parse_params()?;
139///        return Ok(JsonRpcResponse::success(req_id, params[0] + params[1]));
140///     }
141///     m =>  Ok(req.method_not_found(m))
142///   }
143/// }
144/// ```
145pub struct JsonRpcExtractor {
146    pub parsed: Value,
147    pub method: String,
148    pub id: Id,
149}
150
151impl JsonRpcExtractor {
152    pub fn get_answer_id(&self) -> Id {
153        self.id.clone()
154    }
155
156    pub fn parse_params<T: DeserializeOwned>(self) -> Result<T, JsonRpcResponse> {
157        cfg_if::cfg_if! {
158           if #[cfg(feature = "simd")] {
159                match simd_json::serde::from_owned_value(self.parsed){
160                    Ok(v) => Ok(v),
161                    Err(e) => {
162                        let error = JsonRpcError::new(
163                            JsonRpcErrorReason::InvalidParams,
164                            e.to_string(),
165                            Value::default(),
166                        );
167                        Err(JsonRpcResponse::error(self.id, error))
168                    }
169
170                }
171            } else if #[cfg(feature = "serde_json")] {
172                match serde_json::from_value(self.parsed){
173                    Ok(v) => Ok(v),
174                    Err(e) => {
175                        let error = JsonRpcError::new(
176                            JsonRpcErrorReason::InvalidParams,
177                            e.to_string(),
178                            Value::Null,
179                        );
180                        Err(JsonRpcResponse::error(self.id, error))
181                    }
182                }
183            }
184        }
185    }
186
187    pub fn method(&self) -> &str {
188        &self.method
189    }
190
191    pub fn method_not_found(&self, method: &str) -> JsonRpcResponse {
192        let error = JsonRpcError::new(
193            JsonRpcErrorReason::MethodNotFound,
194            format!("Method `{}` not found", method),
195            Value::default(),
196        );
197
198        JsonRpcResponse::error(self.id.clone(), error)
199    }
200}
201
202impl<S> FromRequest<S> for JsonRpcExtractor
203where
204    Bytes: FromRequest<S>,
205    S: Send + Sync,
206{
207    type Rejection = JsonRpcResponse;
208
209    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
210        if !json_content_type(req.headers()) {
211            return Err(JsonRpcResponse {
212                id: Id::None(()),
213                result: JsonRpcAnswer::Error(JsonRpcError::new(
214                    JsonRpcErrorReason::InvalidRequest,
215                    "Invalid content type".to_owned(),
216                    Value::default(),
217                )),
218            });
219        }
220
221        #[allow(unused_mut)]
222        let mut bytes = match Bytes::from_request(req, state).await {
223            Ok(a) => a.to_vec(),
224            Err(_) => {
225                return Err(JsonRpcResponse {
226                    id: Id::None(()),
227                    result: JsonRpcAnswer::Error(JsonRpcError::new(
228                        JsonRpcErrorReason::InvalidRequest,
229                        "Invalid request".to_owned(),
230                        Value::default(),
231                    )),
232                })
233            }
234        };
235
236        cfg_if!(
237            if #[cfg(feature = "simd")] {
238               let parsed: JsonRpcRequest = match simd_json::from_slice(&mut bytes){
239                    Ok(a) => a,
240                    Err(e) => {
241                        return Err(JsonRpcResponse {
242                            id: Id::None(()),
243                            result: JsonRpcAnswer::Error(JsonRpcError::new(
244                                JsonRpcErrorReason::InvalidRequest,
245                                e.to_string(),
246                                Value::default(),
247                            )),
248                        })
249                    }
250                };
251            } else if #[cfg(feature = "serde_json")] {
252               let parsed: JsonRpcRequest = match serde_json::from_slice(&bytes){
253                    Ok(a) => a,
254                    Err(e) => {
255                        return Err(JsonRpcResponse {
256                            id: Id::None(()),
257                            result: JsonRpcAnswer::Error(JsonRpcError::new(
258                                JsonRpcErrorReason::InvalidRequest,
259                                e.to_string(),
260                                Value::default(),
261                            )),
262                        })
263                    }
264                };
265            }
266        );
267
268        Ok(Self {
269            parsed: parsed.params,
270            method: parsed.method,
271            id: parsed.id,
272        })
273    }
274}
275
276fn json_content_type(headers: &HeaderMap) -> bool {
277    let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
278        content_type
279    } else {
280        return false;
281    };
282
283    let content_type = if let Ok(content_type) = content_type.to_str() {
284        content_type
285    } else {
286        return false;
287    };
288
289    let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
290        mime
291    } else {
292        return false;
293    };
294
295    let is_json_content_type = mime.type_() == "application"
296        && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json"));
297
298    is_json_content_type
299}
300
301#[derive(Debug, Clone, PartialEq)]
302/// A JSON-RPC response.
303pub struct JsonRpcResponse {
304    /// Request content.
305    pub result: JsonRpcAnswer,
306    /// The request ID.
307    pub id: Id,
308}
309
310impl JsonRpcResponse {
311    fn new<ID>(id: ID, result: JsonRpcAnswer) -> Self
312    where
313        Id: From<ID>,
314    {
315        Self {
316            result,
317            id: id.into(),
318        }
319    }
320
321    /// Returns a response with the given result
322    /// Returns JsonRpcError if the `result` is invalid input for [`serde_json::to_value`]
323    pub fn success<T, ID>(id: ID, result: T) -> Self
324    where
325        T: Serialize,
326        Id: From<ID>,
327    {
328        cfg_if::cfg_if! {
329          if #[cfg(feature = "simd")] {
330            match simd_json::serde::to_owned_value(result) {
331                Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
332                Err(e) => {
333                    let err = JsonRpcError::new(
334                        JsonRpcErrorReason::InternalError,
335                        e.to_string(),
336                        Value::default(),
337                    );
338                    JsonRpcResponse::error(id, err)
339                }
340            }
341          } else if #[cfg(feature = "serde_json")] {
342            match serde_json::to_value(result) {
343                Ok(v) => JsonRpcResponse::new(id, JsonRpcAnswer::Result(v)),
344                Err(e) => {
345                    let err = JsonRpcError::new(
346                        JsonRpcErrorReason::InternalError,
347                        e.to_string(),
348                        Value::Null,
349                    );
350                    JsonRpcResponse::error(id, err)
351                }
352            }
353          }
354        }
355    }
356
357    pub fn error<ID>(id: ID, error: JsonRpcError) -> Self
358    where
359        Id: From<ID>,
360    {
361        let id = id.into();
362        JsonRpcResponse {
363            result: JsonRpcAnswer::Error(error),
364            id,
365        }
366    }
367}
368
369impl Serialize for JsonRpcResponse {
370    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
371    where
372        S: serde::Serializer,
373    {
374        #[derive(Serialize)]
375        struct Helper<'a> {
376            jsonrpc: &'static str,
377            #[serde(flatten)]
378            result: &'a JsonRpcAnswer,
379            id: Id,
380        }
381
382        Helper {
383            jsonrpc: JSONRPC,
384            result: &self.result,
385            id: self.id.clone(),
386        }
387        .serialize(serializer)
388    }
389}
390
391impl<'de> Deserialize<'de> for JsonRpcResponse {
392    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
393    where
394        D: serde::Deserializer<'de>,
395    {
396        use serde::de::Error;
397
398        #[derive(Deserialize)]
399        struct Helper<'a> {
400            #[serde(borrow)]
401            jsonrpc: Cow<'a, str>,
402            #[serde(flatten)]
403            result: JsonRpcAnswer,
404            id: Id,
405        }
406
407        let helper = Helper::deserialize(deserializer)?;
408        if helper.jsonrpc == JSONRPC {
409            Ok(Self {
410                result: helper.result,
411                id: helper.id,
412            })
413        } else {
414            Err(D::Error::custom("Unknown jsonrpc version"))
415        }
416    }
417}
418
419impl IntoResponse for JsonRpcResponse {
420    fn into_response(self) -> Response {
421        Json(self).into_response()
422    }
423}
424
425#[derive(Serialize, Clone, Debug, Deserialize, PartialEq)]
426#[serde(rename_all = "lowercase")]
427/// JsonRpc [response object](https://www.jsonrpc.org/specification#response_object)
428pub enum JsonRpcAnswer {
429    Result(Value),
430    Error(JsonRpcError),
431}
432
/// The protocol version written into every serialized request/response and required on
/// every deserialized one.
const JSONRPC: &str = "2.0";
434
435/// An identifier established by the Client that MUST contain a String, Number,
436/// or NULL value if included. If it is not included it is assumed to be a notification.
437/// The value SHOULD normally not be Null and Numbers SHOULD NOT contain fractional parts
438#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize, Hash)]
439#[serde(untagged)]
440pub enum Id {
441    Num(i64),
442    Str(String),
443    None(()),
444}
445
446impl From<()> for Id {
447    fn from(val: ()) -> Self {
448        Id::None(val)
449    }
450}
451
452impl From<i64> for Id {
453    fn from(val: i64) -> Self {
454        Id::Num(val)
455    }
456}
457
458impl From<String> for Id {
459    fn from(val: String) -> Self {
460        Id::Str(val)
461    }
462}
463
#[cfg(test)]
#[cfg(all(feature = "anyhow_error", feature = "serde_json"))]
mod test {
    use crate::{
        Deserialize, JrpcResult, JsonRpcAnswer, JsonRpcError, JsonRpcErrorReason, JsonRpcExtractor,
        JsonRpcRequest, JsonRpcResponse,
    };
    use axum::http::StatusCode;
    use axum::routing::post;
    use axum::Router;
    use axum_test::TestServer;
    use serde::Serialize;
    use serde_json::{json, Value};

    /// Posts one JSON-RPC request (id `0`) to `/` and asserts the HTTP status is `200 OK`.
    /// Returns the decoded JSON-RPC response.
    async fn call(client: &TestServer, method: &str, params: Value) -> JsonRpcResponse {
        let res = client
            .post("/")
            .json(&JsonRpcRequest {
                id: 0.into(),
                method: method.to_owned(),
                params,
            })
            .await;
        assert_eq!(res.status_code(), StatusCode::OK);
        res.json::<JsonRpcResponse>()
    }

    /// Asserts that two responses serialize to the same JSON.
    fn assert_same_json(expected: JsonRpcResponse, actual: JsonRpcResponse) {
        assert_eq!(
            serde_json::to_value(expected).unwrap(),
            serde_json::to_value(actual).unwrap()
        );
    }

    #[tokio::test]
    async fn test() {
        let app = Router::new().route("/", post(handler));
        let client = TestServer::new(app).unwrap();

        // Named (object) params deserialized into a struct.
        let response = call(
            &client,
            "add",
            serde_json::to_value(Test { a: 0, b: 111 }).unwrap(),
        )
        .await;
        assert_eq!(response.result, JsonRpcAnswer::Result(111.into()));

        // Positional params through a fallible method that succeeds.
        let response = call(&client, "sub", json!([5, 3])).await;
        assert_eq!(response.result, JsonRpcAnswer::Result(2.into()));

        // Unknown method.
        let response = call(&client, "lol", serde_json::to_value(()).unwrap()).await;
        let expected = JsonRpcResponse::error(
            0,
            JsonRpcError::new(
                JsonRpcErrorReason::MethodNotFound,
                format!("Method `{}` not found", "lol"),
                Value::Null,
            ),
        );
        assert_same_json(expected, response);

        // A custom error type converted into a server error.
        let response = call(&client, "div", json!([1, 0])).await;
        let expected = JsonRpcResponse::error(0, JsonRpcError::from(CustomError::DivideByZero));
        assert_same_json(expected, response);
    }

    async fn handler(value: JsonRpcExtractor) -> JrpcResult {
        let answer_id = value.get_answer_id();
        match value.method.as_str() {
            "add" => {
                let request: Test = value.parse_params()?;
                let result = request.a + request.b;
                Ok(JsonRpcResponse::success(answer_id, result))
            }
            "sub" => {
                let result: [i32; 2] = value.parse_params()?;
                let result = match failing_sub(result[0], result[1]).await {
                    Ok(result) => result,
                    Err(e) => return Err(JsonRpcResponse::error(answer_id, e.into())),
                };
                Ok(JsonRpcResponse::success(answer_id, result))
            }
            "div" => {
                let result: [i32; 2] = value.parse_params()?;
                let result = match failing_div(result[0], result[1]).await {
                    Ok(result) => result,
                    Err(e) => return Err(JsonRpcResponse::error(answer_id, e.into())),
                };

                Ok(JsonRpcResponse::success(answer_id, result))
            }
            method => Ok(value.method_not_found(method)),
        }
    }

    async fn failing_sub(a: i32, b: i32) -> anyhow::Result<i32> {
        anyhow::ensure!(a > b, "a must be greater than b");
        Ok(a - b)
    }

    async fn failing_div(a: i32, b: i32) -> Result<i32, CustomError> {
        if b == 0 {
            Err(CustomError::DivideByZero)
        } else {
            Ok(a / b)
        }
    }

    #[derive(Deserialize, Serialize, Debug)]
    struct Test {
        a: i32,
        b: i32,
    }

    #[derive(Debug, thiserror::Error)]
    enum CustomError {
        #[error("Divisor must not be equal to 0")]
        DivideByZero,
    }

    impl From<CustomError> for JsonRpcError {
        fn from(error: CustomError) -> Self {
            JsonRpcError::new(
                JsonRpcErrorReason::ServerError(-32099),
                error.to_string(),
                serde_json::Value::Null,
            )
        }
    }
}