1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//! Utilities for MessagePack request bodies: decoding them from `salvo` requests and attaching them to `reqwest` requests.

use crate::prelude::*;

#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
use serde::Deserialize;

#[cfg(feature = "reqwest")]
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
use serde::Serialize;

#[cfg(feature = "salvo")]
use salvo::Request;

#[cfg(feature = "reqwest")]
use reqwest::RequestBuilder;

/// Extension trait for salvo [`Request`] that decodes a MessagePack request body.
#[cfg(feature = "salvo")]
#[salvo::async_trait]
pub trait MsgPackParser {
  /// Decodes the body as `T`.
  ///
  /// The `Request` implementation applies salvo's default body size limit.
  async fn parse_msgpack<'de, T>(&'de mut self) -> MResult<T>
  where
    T: Deserialize<'de>;

  /// Decodes the body as `T`, reading at most `max_size` bytes of payload.
  async fn parse_msgpack_with_max_size<'de, T>(&'de mut self, max_size: usize) -> MResult<T>
  where
    T: Deserialize<'de>;
}

#[cfg(feature = "salvo")]
#[salvo::async_trait]
impl MsgPackParser for Request {
  /// Parse MessagePack body as type `T` from request with salvo's default max size limit
  /// (`salvo::http::request::secure_max_size()`).
  ///
  /// Errors are the same as for `parse_msgpack_with_max_size`.
  #[inline]
  async fn parse_msgpack<'de, T: Deserialize<'de>>(&'de mut self) -> MResult<T> {
    self.parse_msgpack_with_max_size(salvo::http::request::secure_max_size()).await
  }

  /// Parse MessagePack body as type `T` from request, reading at most `max_size` bytes.
  ///
  /// An empty body is decoded as a MessagePack `nil`. `Option<_>` and `()` targets therefore
  /// accept it, and other targets fail to decode.
  ///
  /// # Errors
  ///
  /// - `400 Bad Request` with "Bad content type, must be `application/msgpack`." when the
  ///   request has no content type, or its subtype is not `msgpack`.
  /// - The error returned by `payload_with_max_size` (converted with `?`) when the body
  ///   cannot be read within `max_size`.
  /// - The `rmp_serde` decode error, passed through `consider` with `StatusCode::BAD_REQUEST`,
  ///   when the body is not a valid encoding of `T`.
  #[inline]
  async fn parse_msgpack_with_max_size<'de, T: Deserialize<'de>>(&'de mut self, max_size: usize) -> MResult<T> {
    // MessagePack encoding of `nil` (format byte 0xc0): the MessagePack counterpart of JSON `null`.
    const MSGPACK_NIL: &[u8] = &[0xc0];

    let ctype = self.content_type();
    if let Some(ctype) = ctype {
      // Only the subtype is compared, so any `*/msgpack` content type is accepted.
      if ctype.subtype() == salvo::http::mime::MSGPACK {
        let payload = self.payload_with_max_size(max_size).await?;
        // Substitute an explicit `nil` for an empty body. The bytes of the JSON literal `null`
        // would be read by MessagePack as the positive fixint 110 (0x6e), e.g. `Some(110)`.
        let payload: &[u8] = if payload.is_empty() {
          MSGPACK_NIL
        } else {
          payload.as_ref()
        };
        log::debug!("{:?}", payload);
        return Ok(
          rmp_serde::from_slice::<T>(payload).consider(Some(StatusCode::BAD_REQUEST), None, true)?
        );
      }
    }
    Err(
      ErrorResponse {
        status_code: Some(StatusCode::BAD_REQUEST),
        error_text: "Bad content type, must be `application/msgpack`.".into(),
        original_text: None,
        public_error: true
      }
    )
  }
}

/// Extension trait for `reqwest` [`RequestBuilder`] that attaches a MessagePack-encoded body.
#[cfg(feature = "reqwest")]
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
pub trait MsgPackBuilder {
  /// Encodes `msgpack` and sets it as the request body, returning a builder for further configuration.
  ///
  /// The `RequestBuilder` implementation also adds `Content-Type: application/msgpack`
  /// unless a content type is already set.
  fn msgpack<T>(self, msgpack: &T) -> MResult<RequestBuilder>
  where
    T: Serialize + ?Sized;
}

#[cfg(feature = "reqwest")]
#[cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))]
impl MsgPackBuilder for RequestBuilder {
  /// Builds the request, encodes `msgpack` with `rmp_serde::to_vec` and stores the bytes as the
  /// body, then reassembles a `RequestBuilder` from the client and the updated request.
  ///
  /// `Content-Type: application/msgpack` is inserted only when no content type was set
  /// beforehand, so an explicit header on the builder takes precedence.
  ///
  /// # Errors
  ///
  /// - The builder's own error (converted with `?`) if the request could not be built;
  ///   the value is not serialized in that case.
  /// - The `rmp_serde` encoding error, converted from its message string.
  fn msgpack<T: Serialize + ?Sized>(self, msgpack: &T) -> MResult<RequestBuilder> {
    use reqwest::header::{HeaderValue, CONTENT_TYPE};

    let (client, built) = self.build_split();
    let mut request = built?;

    let encoded = match rmp_serde::to_vec(msgpack) {
      Ok(bytes) => bytes,
      Err(err) => return Err(err.to_string().into()),
    };

    if !request.headers().contains_key(CONTENT_TYPE) {
      request
        .headers_mut()
        .insert(CONTENT_TYPE, HeaderValue::from_static("application/msgpack"));
    }
    *request.body_mut() = Some(encoded.into());

    Ok(RequestBuilder::from_parts(client, request))
  }
}

/// WebAssembly variant of the MessagePack body extension for `reqwest` [`RequestBuilder`].
///
/// Unlike the native variant, it returns the finished `reqwest::Request` rather than a builder.
// NOTE(review): the file-level `use serde::Serialize` is gated to non-wasm targets, so on wasm
// `Serialize` must be provided by `crate::prelude` — confirm.
#[cfg(feature = "reqwest")]
#[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
pub trait MsgPackBuilder {
  /// Builds the request and sets the MessagePack encoding of `msgpack` as its body.
  fn msgpack<T>(self, msgpack: &T) -> CResult<reqwest::Request>
  where
    T: Serialize + ?Sized;
}

#[cfg(feature = "reqwest")]
#[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
impl MsgPackBuilder for RequestBuilder {
  /// Builds the request and stores the `rmp_serde::to_vec` encoding of `msgpack` as its body.
  ///
  /// `Content-Type: application/msgpack` is inserted only when no content type was set
  /// beforehand.
  ///
  /// # Errors
  ///
  /// - The build error (converted with `?`) if the request could not be built; the value is
  ///   not serialized in that case.
  /// - The `rmp_serde` encoding error, converted from its message string.
  fn msgpack<T: Serialize + ?Sized>(self, msgpack: &T) -> CResult<reqwest::Request> {
    use reqwest::header::{HeaderValue, CONTENT_TYPE};

    let mut request = self.build()?;

    let encoded = match rmp_serde::to_vec(msgpack) {
      Ok(bytes) => bytes,
      Err(err) => return Err(err.to_string().into()),
    };

    if !request.headers().contains_key(CONTENT_TYPE) {
      request
        .headers_mut()
        .insert(CONTENT_TYPE, HeaderValue::from_static("application/msgpack"));
    }
    *request.body_mut() = Some(encoded.into());

    Ok(request)
  }
}