json_rpc_server/
server.rs1use std::{fmt::Debug, future::Future, net::SocketAddr, pin::Pin, sync::Arc};
2
3use anyhow::{anyhow, Result};
4use async_trait::async_trait;
5use bytes::Bytes;
6use http_body_util::{BodyExt, Full};
7use hyper::{
8 body::Incoming,
9 server::conn::http1,
10 service::{service_fn, Service},
11 Request, Response,
12};
13use hyper_util::rt::TokioIo;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use tokio::net::TcpListener;
17
18use crate::{RPCError, RPCRequest, RPCResponse};
19
20#[async_trait]
21pub trait Handle {
22 type Request: for<'de> Deserialize<'de> + Send + Sync + Clone + 'static;
23 type Response: Serialize + Send;
24
25 async fn handle(
26 &self,
27 method: &str,
28 req: Option<Self::Request>,
29 ) -> std::result::Result<Option<Self::Response>, RPCError>;
30
31 async fn batch_handle(
32 &self,
33 reqests: Vec<RPCRequest<Option<Self::Request>>>,
34 ) -> Vec<RPCResponse<Self::Response>> {
35 let mut response = vec![];
36 for reqest in reqests {
37 let resp = self
38 .handle(&reqest.method, reqest.params)
39 .await
40 .map_or_else(
41 |e| RPCResponse::error(reqest.id.clone(), e),
42 |v| RPCResponse::result(reqest.id.clone(), v),
43 );
44 response.push(resp);
45 }
46 response
47 }
48}
49
50async fn _handle<H>(req_body: serde_json::Value, handle: &H) -> Result<serde_json::Value>
51where
52 H: Handle,
53 H::Request: Debug,
54{
55 let req: RPCRequest<Option<H::Request>> = serde_json::from_value(req_body)?;
56
57 log::info!("Get call method: {}", &req.method);
58 log::debug!("Params is: {:?}", &req.params);
59
60 let r = match handle.handle(&req.method, req.params).await {
61 Ok(v) => RPCResponse::result(req.id, v),
62 Err(e) => RPCResponse::error(req.id, e),
63 };
64
65 r.into_value()
66}
67async fn _batch_handle<H>(req_body: serde_json::Value, handle: &H) -> Result<serde_json::Value>
68where
69 H: Handle + Sync,
70 H::Request: Debug,
71{
72 let req: Vec<RPCRequest<Option<H::Request>>> = serde_json::from_value(req_body)?;
73
74 log::debug!("Batch params is: {:?}", &req);
75
76 let r = handle.batch_handle(req).await;
77
78 Ok(serde_json::to_value(r)?)
79}
80
/// hyper service wrapper that shares one [`Handle`] implementation across
/// requests through an `Arc`.
struct HandleHttp<H> {
    /// The JSON-RPC handler every incoming request is dispatched to.
    handle: Arc<H>,
}
84
85impl<H> Service<Request<Incoming>> for HandleHttp<H>
86where
87 H: Handle + Send + Sync + 'static,
88 H::Request: Debug,
89{
90 type Response = Response<Full<Bytes>>;
91 type Error = anyhow::Error;
92 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
93
94 fn call(&self, request: Request<Incoming>) -> Self::Future {
95 let handle = self.handle.clone();
96
97 let r = async move {
98 let req_body = request
99 .into_body()
100 .collect()
101 .await
102 .map_err(|e| anyhow!("{e}"))?
103 .to_bytes()
104 .to_vec();
105
106 log::debug!("Request Body: {:?}", req_body);
107
108 let req_body = serde_json::from_slice::<Value>(&req_body)?;
109
110 let body = if req_body.is_object() {
111 let r = _handle(req_body, handle.as_ref()).await?;
112 serde_json::to_string(&r)?
113 } else if req_body.is_array() {
114 let r = _batch_handle(req_body, handle.as_ref()).await?;
115 serde_json::to_string(&r)?
116 } else {
117 return Err(anyhow!("Unsupport type"));
118 };
119 log::debug!("Response Body: {:?}", body);
120
121 let resp = Response::builder()
122 .header("Content-Type", "application/json")
123 .body(Full::new(Bytes::from(body)))?;
124
125 Ok(resp)
126 };
127
128 Box::pin(r)
129 }
130}
131
132pub async fn serve<H>(addr: &SocketAddr, handle: H) -> Result<()>
133where
134 H: Handle + Send + Sync + 'static,
135 H::Request: Debug,
136{
137 let listener = TcpListener::bind(addr).await?;
138 println!("Listening on http://{}", addr);
139
140 let handle = Arc::new(handle);
141
142 loop {
143 let (stream, _) = listener.accept().await?;
144 let io = TokioIo::new(stream);
145
146 let handle = handle.clone();
147 let service = service_fn(move |req| {
148 let value = handle.clone();
149 async move { HandleHttp { handle: value }.call(req).await }
150 });
151
152 if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
153 println!("Error serving connection: {:?}", err);
154 }
155 }
156}