bma_jrpc/
lib.rs

1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2
3pub use bma_jrpc_derive::rpc_client;
4use futures_lite::io::AsyncReadExt;
5use http::status::StatusCode;
6use isahc::config::Configurable;
7use isahc::{AsyncReadResponseExt, ReadResponseExt, RequestExt};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use std::fmt;
10use std::sync::atomic;
11use std::time::Duration;
12
/// Protocol version sent in, and required from, every envelope.
const JSONRPC_VER: &str = "2.0";
/// Per-request timeout used until [`HttpClient::timeout`] overrides it.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// `content-type` announced by the JSON encoder.
const MIME_JSON: &str = "application/json";
/// `content-type` announced by the MessagePack encoder.
#[cfg(feature = "msgpack")]
const MIME_MSGPACK: &str = "application/msgpack";
19
20pub trait Encoder: Default {
21    fn encode<P: Serialize>(&self, payload: &P) -> Result<Vec<u8>, Error>;
22    fn decode<'a, R: Deserialize<'a>>(&self, data: &'a [u8]) -> Result<R, Error>;
23    fn mime(&self) -> &'static str;
24}
25
/// JSON [`Encoder`]: payloads are serialized with `serde_json` and sent as
/// `application/json`.
///
/// A stateless marker type, so it is cheap to copy and print.
#[derive(Debug, Default, Clone, Copy)]
pub struct Json {}
28
29impl Encoder for Json {
30    #[inline]
31    fn encode<P: Serialize>(&self, payload: &P) -> Result<Vec<u8>, Error> {
32        serde_json::to_vec(payload).map_err(Into::into)
33    }
34    #[inline]
35    fn decode<'a, R: Deserialize<'a>>(&self, data: &'a [u8]) -> Result<R, Error> {
36        serde_json::from_slice(data).map_err(Into::into)
37    }
38    #[inline]
39    fn mime(&self) -> &'static str {
40        MIME_JSON
41    }
42}
43
/// MessagePack [`Encoder`]: payloads are serialized with `rmp_serde` (as maps with field
/// names) and sent as `application/msgpack`.
///
/// A stateless marker type, so it is cheap to copy and print.
#[cfg(feature = "msgpack")]
#[derive(Debug, Default, Clone, Copy)]
pub struct MsgPack {}
47
#[cfg(feature = "msgpack")]
impl Encoder for MsgPack {
    /// Serializes `payload` to MessagePack, encoding structs as maps keyed by field name.
    #[inline]
    fn encode<P: Serialize>(&self, payload: &P) -> Result<Vec<u8>, Error> {
        let bytes = rmp_serde::to_vec_named(payload)?;
        Ok(bytes)
    }
    /// Parses `data` as MessagePack, borrowing from it where `R` allows.
    #[inline]
    fn decode<'a, R: Deserialize<'a>>(&self, data: &'a [u8]) -> Result<R, Error> {
        let value = rmp_serde::from_slice::<R>(data)?;
        Ok(value)
    }
    /// Always `application/msgpack`.
    #[inline]
    fn mime(&self) -> &'static str {
        MIME_MSGPACK
    }
}
63
64#[derive(Serialize)]
65struct Request<'a, P> {
66    jsonrpc: &'static str,
67    id: usize,
68    method: &'a str,
69    params: P,
70}
71
72#[derive(Deserialize)]
73struct Response<'a, R> {
74    jsonrpc: &'a str,
75    id: usize,
76    result: Option<R>,
77    error: Option<RpcError>,
78}
79
80#[derive(Deserialize, Debug)]
81#[allow(clippy::module_name_repetitions)]
82pub struct RpcError {
83    code: i16,
84    message: Option<String>,
85}
86
87impl RpcError {
88    #[inline]
89    pub fn code(&self) -> i16 {
90        self.code
91    }
92    #[inline]
93    pub fn message(&self) -> Option<&str> {
94        self.message.as_deref()
95    }
96}
97
98#[inline]
99pub fn http_client(url: &str) -> HttpClient<Json> {
100    HttpClient::<Json>::new(url)
101}
102
103pub struct HttpClient<C>
104where
105    C: Encoder,
106{
107    req_id: atomic::AtomicUsize,
108    url: String,
109    timeout: Duration,
110    encoder: C,
111}
112
113pub trait Rpc {
114    fn call<P: Serialize, R: DeserializeOwned>(&self, method: &str, params: P) -> Result<R, Error>;
115}
116
117impl<C> Rpc for HttpClient<C>
118where
119    C: Encoder,
120{
121    fn call<P, R>(&self, method: &str, params: P) -> Result<R, Error>
122    where
123        P: Serialize,
124        R: DeserializeOwned,
125    {
126        let (http_request, id) = self.prepare_http_request(method, params)?;
127        let mut http_response = http_request.send()?;
128        if http_response.status() == StatusCode::OK {
129            self.parse_response(&http_response.bytes()?, id)
130        } else {
131            Err(Error::Http(http_response.status(), http_response.text()?))
132        }
133    }
134}
135
136impl<C> HttpClient<C>
137where
138    C: Encoder,
139{
140    #[inline]
141    pub fn new(url: &str) -> Self {
142        Self {
143            url: url.to_owned(),
144            timeout: DEFAULT_TIMEOUT,
145            req_id: atomic::AtomicUsize::new(0),
146            encoder: C::default(),
147        }
148    }
149    #[inline]
150    pub fn timeout(mut self, timeout: Duration) -> Self {
151        self.timeout = timeout;
152        self
153    }
154    #[inline]
155    fn prepare_http_request<'a, P: Serialize>(
156        &'a self,
157        method: &'a str,
158        params: P,
159    ) -> Result<(isahc::Request<Vec<u8>>, usize), Error> {
160        let req = Request {
161            jsonrpc: JSONRPC_VER,
162            id: self.req_id.fetch_add(1, atomic::Ordering::SeqCst),
163            method,
164            params,
165        };
166        let payload = self.encoder.encode(&req)?;
167        Ok((
168            isahc::Request::post(&self.url)
169                .timeout(self.timeout)
170                .header("content-type", self.encoder.mime())
171                .body(payload)?,
172            req.id,
173        ))
174    }
175    pub async fn call_async<P, R>(&self, method: &str, params: P) -> Result<R, Error>
176    where
177        P: Serialize,
178        R: DeserializeOwned,
179    {
180        let (http_request, id) = self.prepare_http_request(method, params)?;
181        let mut resp = http_request.send_async().await?;
182        if resp.status() == StatusCode::OK {
183            let mut buf =
184                Vec::with_capacity(usize::try_from(resp.body().len().unwrap_or_default())?);
185            resp.body_mut().read_to_end(&mut buf).await?;
186            self.parse_response(&buf, id)
187        } else {
188            Err(Error::Http(resp.status(), resp.text().await?))
189        }
190    }
191    fn parse_response<'a, R: Deserialize<'a>>(&self, buf: &'a [u8], id: usize) -> Result<R, Error> {
192        let resp: Response<R> = self.encoder.decode(buf)?;
193        if resp.jsonrpc != JSONRPC_VER {
194            return Err(Error::Protocol("invalid JSON RPC version"));
195        }
196        if resp.id != id {
197            return Err(Error::Protocol("invalid response ID"));
198        }
199        if let Some(err) = resp.error {
200            Err(Error::Rpc(err))
201        } else if let Some(result) = resp.result {
202            Ok(result)
203        } else {
204            Err(Error::Protocol("no result/error fields"))
205        }
206    }
207}
208
209#[derive(Debug)]
210pub enum Error {
211    Protocol(&'static str),
212    Rpc(RpcError),
213    Transport(isahc::Error),
214    Http(StatusCode, String),
215    Other(Box<dyn std::error::Error + Send + Sync>),
216}
217
218impl fmt::Display for Error {
219    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220        match self {
221            Error::Protocol(s) => write!(f, "invalid server response: {}", s),
222            Error::Rpc(e) => write!(f, "{} {}", e.code, e.message.as_deref().unwrap_or_default()),
223            Error::Transport(s) => write!(f, "{}", s),
224            Error::Http(code, s) => write!(f, "{} {}", code, s),
225            Error::Other(e) => write!(f, "{}", e),
226        }
227    }
228}
229
230impl std::error::Error for Error {}
231
// Implements `From<$source>` for `Error` by boxing the value into `Error::Other`.
macro_rules! impl_other_err {
    ($source:ty) => {
        impl From<$source> for Error {
            #[inline]
            fn from(source: $source) -> Self {
                Error::Other(source.into())
            }
        }
    };
}
241
242impl From<isahc::http::Error> for Error {
243    fn from(err: isahc::http::Error) -> Self {
244        Self::Transport(err.into())
245    }
246}
247
248impl From<isahc::Error> for Error {
249    fn from(err: isahc::Error) -> Self {
250        Self::Transport(err)
251    }
252}
253
254impl_other_err!(serde_json::Error);
255#[cfg(feature = "msgpack")]
256impl_other_err!(rmp_serde::decode::Error);
257#[cfg(feature = "msgpack")]
258impl_other_err!(rmp_serde::encode::Error);
259impl_other_err!(std::io::Error);
260impl_other_err!(std::num::TryFromIntError);