json_rpc_server/
server.rs

1use std::{fmt::Debug, future::Future, net::SocketAddr, pin::Pin, sync::Arc};
2
3use anyhow::{anyhow, Result};
4use async_trait::async_trait;
5use bytes::Bytes;
6use http_body_util::{BodyExt, Full};
7use hyper::{
8    body::Incoming,
9    server::conn::http1,
10    service::{service_fn, Service},
11    Request, Response,
12};
13use hyper_util::rt::TokioIo;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use tokio::net::TcpListener;
17
18use crate::{RPCError, RPCRequest, RPCResponse};
19
20#[async_trait]
21pub trait Handle {
22    type Request: for<'de> Deserialize<'de> + Send + Sync + Clone + 'static;
23    type Response: Serialize + Send;
24
25    async fn handle(
26        &self,
27        method: &str,
28        req: Option<Self::Request>,
29    ) -> std::result::Result<Option<Self::Response>, RPCError>;
30
31    async fn batch_handle(
32        &self,
33        reqests: Vec<RPCRequest<Option<Self::Request>>>,
34    ) -> Vec<RPCResponse<Self::Response>> {
35        let mut response = vec![];
36        for reqest in reqests {
37            let resp = self
38                .handle(&reqest.method, reqest.params)
39                .await
40                .map_or_else(
41                    |e| RPCResponse::error(reqest.id.clone(), e),
42                    |v| RPCResponse::result(reqest.id.clone(), v),
43                );
44            response.push(resp);
45        }
46        response
47    }
48}
49
50async fn _handle<H>(req_body: serde_json::Value, handle: &H) -> Result<serde_json::Value>
51where
52    H: Handle,
53    H::Request: Debug,
54{
55    let req: RPCRequest<Option<H::Request>> = serde_json::from_value(req_body)?;
56
57    log::info!("Get call method: {}", &req.method);
58    log::debug!("Params is: {:?}", &req.params);
59
60    let r = match handle.handle(&req.method, req.params).await {
61        Ok(v) => RPCResponse::result(req.id, v),
62        Err(e) => RPCResponse::error(req.id, e),
63    };
64
65    r.into_value()
66}
67async fn _batch_handle<H>(req_body: serde_json::Value, handle: &H) -> Result<serde_json::Value>
68where
69    H: Handle + Sync,
70    H::Request: Debug,
71{
72    let req: Vec<RPCRequest<Option<H::Request>>> = serde_json::from_value(req_body)?;
73
74    log::debug!("Batch params is: {:?}", &req);
75
76    let r = handle.batch_handle(req).await;
77
78    Ok(serde_json::to_value(r)?)
79}
80
81struct HandleHttp<H> {
82    handle: Arc<H>,
83}
84
85impl<H> Service<Request<Incoming>> for HandleHttp<H>
86where
87    H: Handle + Send + Sync + 'static,
88    H::Request: Debug,
89{
90    type Response = Response<Full<Bytes>>;
91    type Error = anyhow::Error;
92    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
93
94    fn call(&self, request: Request<Incoming>) -> Self::Future {
95        let handle = self.handle.clone();
96
97        let r = async move {
98            let req_body = request
99                .into_body()
100                .collect()
101                .await
102                .map_err(|e| anyhow!("{e}"))?
103                .to_bytes()
104                .to_vec();
105
106            log::debug!("Request Body: {:?}", req_body);
107
108            let req_body = serde_json::from_slice::<Value>(&req_body)?;
109
110            let body = if req_body.is_object() {
111                let r = _handle(req_body, handle.as_ref()).await?;
112                serde_json::to_string(&r)?
113            } else if req_body.is_array() {
114                let r = _batch_handle(req_body, handle.as_ref()).await?;
115                serde_json::to_string(&r)?
116            } else {
117                return Err(anyhow!("Unsupport type"));
118            };
119            log::debug!("Response Body: {:?}", body);
120
121            let resp = Response::builder()
122                .header("Content-Type", "application/json")
123                .body(Full::new(Bytes::from(body)))?;
124
125            Ok(resp)
126        };
127
128        Box::pin(r)
129    }
130}
131
132pub async fn serve<H>(addr: &SocketAddr, handle: H) -> Result<()>
133where
134    H: Handle + Send + Sync + 'static,
135    H::Request: Debug,
136{
137    let listener = TcpListener::bind(addr).await?;
138    println!("Listening on http://{}", addr);
139
140    let handle = Arc::new(handle);
141
142    loop {
143        let (stream, _) = listener.accept().await?;
144        let io = TokioIo::new(stream);
145
146        let handle = handle.clone();
147        let service = service_fn(move |req| {
148            let value = handle.clone();
149            async move { HandleHttp { handle: value }.call(req).await }
150        });
151
152        if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
153            println!("Error serving connection: {:?}", err);
154        }
155    }
156}