// rpc_toolkit/server/http.rs

1use axum::body::Body;
2use axum::extract::Request;
3use axum::handler::Handler;
4use axum::response::Response;
5use futures::future::{join_all, BoxFuture};
6use futures::{Future, FutureExt};
7use http::header::{CONTENT_LENGTH, CONTENT_TYPE};
8use http_body_util::BodyExt;
9use imbl_value::imbl::Vector;
10use imbl_value::Value;
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use yajrc::{RpcError, RpcMethod};
14
15use crate::server::{RpcRequest, RpcResponse, SingleOrBatchRpcRequest};
16use crate::util::{internal_error, parse_error};
17use crate::{HandleAny, Server};
18
/// Pre-serialized JSON-RPC error response body used when a real response
/// cannot be serialized or assembled.
///
/// It is a complete JSON-RPC 2.0 response object: `jsonrpc` and `id` are
/// required members of every response. `id` is `null` because no request id
/// is available to the fallback path.
const FALLBACK_ERROR: &str = "{\"jsonrpc\":\"2.0\",\"id\":null,\"error\":{\"code\":-32603,\"message\":\"Internal error\",\"data\":\"Failed to serialize rpc response\"}}";

21pub fn fallback_rpc_error_response() -> Response {
22    Response::builder()
23        .header(CONTENT_TYPE, "application/json")
24        .header(CONTENT_LENGTH, FALLBACK_ERROR.len())
25        .body(Body::from(FALLBACK_ERROR.as_bytes()))
26        .unwrap()
27}
28
29pub fn json_http_response<T: Serialize>(t: &T) -> Response {
30    let body = match serde_json::to_vec(t) {
31        Ok(a) => a,
32        Err(_) => return fallback_rpc_error_response(),
33    };
34    Response::builder()
35        .header(CONTENT_TYPE, "application/json")
36        .header(CONTENT_LENGTH, body.len())
37        .body(Body::from(body))
38        .unwrap_or_else(|_| fallback_rpc_error_response())
39}
40
41pub trait Middleware<Context: Send + 'static>: Clone + Send + Sync + 'static {
42    type Metadata: DeserializeOwned + Send + 'static;
43    #[allow(unused_variables)]
44    fn process_http_request(
45        &mut self,
46        context: &Context,
47        request: &mut Request,
48    ) -> impl Future<Output = Result<(), Response>> + Send {
49        async { Ok(()) }
50    }
51    #[allow(unused_variables)]
52    fn process_rpc_request(
53        &mut self,
54        context: &Context,
55        metadata: Self::Metadata,
56        request: &mut RpcRequest,
57    ) -> impl Future<Output = Result<(), RpcResponse>> + Send {
58        async { Ok(()) }
59    }
60    #[allow(unused_variables)]
61    fn process_rpc_response(
62        &mut self,
63        context: &Context,
64        response: &mut RpcResponse,
65    ) -> impl Future<Output = ()> + Send {
66        async { () }
67    }
68    #[allow(unused_variables)]
69    fn process_http_response(
70        &mut self,
71        context: &Context,
72        response: &mut Response,
73    ) -> impl Future<Output = ()> + Send {
74        async { () }
75    }
76}
77
78#[allow(private_bounds)]
79trait _Middleware<Context>: Send + Sync {
80    fn dyn_clone(&self) -> DynMiddleware<Context>;
81    fn process_http_request<'a>(
82        &'a mut self,
83        context: &'a Context,
84        request: &'a mut Request,
85    ) -> BoxFuture<'a, Result<(), Response>>;
86    fn process_rpc_request<'a>(
87        &'a mut self,
88        context: &'a Context,
89        metadata: Value,
90        request: &'a mut RpcRequest,
91    ) -> BoxFuture<'a, Result<(), RpcResponse>>;
92    fn process_rpc_response<'a>(
93        &'a mut self,
94
95        context: &'a Context,
96        response: &'a mut RpcResponse,
97    ) -> BoxFuture<'a, ()>;
98    fn process_http_response<'a>(
99        &'a mut self,
100        context: &'a Context,
101        response: &'a mut Response,
102    ) -> BoxFuture<'a, ()>;
103}
104impl<Context: Send + 'static, T: Middleware<Context> + Send + Sync> _Middleware<Context> for T {
105    fn dyn_clone(&self) -> DynMiddleware<Context> {
106        DynMiddleware(Box::new(<Self as Clone>::clone(&self)))
107    }
108    fn process_http_request<'a>(
109        &'a mut self,
110        context: &'a Context,
111        request: &'a mut Request,
112    ) -> BoxFuture<'a, Result<(), Response>> {
113        <Self as Middleware<Context>>::process_http_request(self, context, request).boxed()
114    }
115    fn process_rpc_request<'a>(
116        &'a mut self,
117        context: &'a Context,
118        metadata: Value,
119        request: &'a mut RpcRequest,
120    ) -> BoxFuture<'a, Result<(), RpcResponse>> {
121        <Self as Middleware<Context>>::process_rpc_request(
122            self,
123            context,
124            match imbl_value::from_value(metadata) {
125                Ok(a) => a,
126                Err(e) => return async { Err(internal_error(e).into()) }.boxed(),
127            },
128            request,
129        )
130        .boxed()
131    }
132    fn process_rpc_response<'a>(
133        &'a mut self,
134        context: &'a Context,
135        response: &'a mut RpcResponse,
136    ) -> BoxFuture<'a, ()> {
137        <Self as Middleware<Context>>::process_rpc_response(self, context, response).boxed()
138    }
139    fn process_http_response<'a>(
140        &'a mut self,
141        context: &'a Context,
142        response: &'a mut Response,
143    ) -> BoxFuture<'a, ()> {
144        <Self as Middleware<Context>>::process_http_response(self, context, response).boxed()
145    }
146}
147
148struct DynMiddleware<Context>(Box<dyn _Middleware<Context>>);
149impl<Context> Clone for DynMiddleware<Context> {
150    fn clone(&self) -> Self {
151        self.0.dyn_clone()
152    }
153}
154
155pub struct HttpServer<Context: crate::Context> {
156    inner: Server<Context>,
157    middleware: Vector<DynMiddleware<Context>>,
158}
159impl<Context: crate::Context> Clone for HttpServer<Context> {
160    fn clone(&self) -> Self {
161        Self {
162            inner: self.inner.clone(),
163            middleware: self.middleware.clone(),
164        }
165    }
166}
167impl<Context: crate::Context> Server<Context> {
168    pub fn for_http(self) -> HttpServer<Context> {
169        HttpServer {
170            inner: self,
171            middleware: Vector::new(),
172        }
173    }
174    pub fn middleware<T: Middleware<Context>>(self, middleware: T) -> HttpServer<Context> {
175        self.for_http().middleware(middleware)
176    }
177}
178impl<Context: crate::Context> HttpServer<Context> {
179    pub fn middleware<T: Middleware<Context>>(mut self, middleware: T) -> Self {
180        self.middleware
181            .push_back(DynMiddleware(Box::new(middleware)));
182        self
183    }
184    async fn process_http_request(&self, mut req: Request) -> Response {
185        let mut mid = self.middleware.clone();
186        match async {
187            let ctx = (self.inner.make_ctx)().await?;
188            for middleware in mid.iter_mut().rev() {
189                if let Err(e) = middleware.0.process_http_request(&ctx, &mut req).await {
190                    return Ok::<_, RpcError>(e);
191                }
192            }
193            let (_, body) = req.into_parts();
194            match serde_json::from_slice::<SingleOrBatchRpcRequest>(
195                &*body.collect().await.map_err(internal_error)?.to_bytes(),
196            )
197            .map_err(parse_error)?
198            {
199                SingleOrBatchRpcRequest::Single(rpc_req) => {
200                    let mut res = json_http_response(
201                        &self.process_rpc_request(&ctx, &mut mid, rpc_req).await,
202                    );
203                    for middleware in mid.iter_mut() {
204                        middleware.0.process_http_response(&ctx, &mut res).await;
205                    }
206                    Ok(res)
207                }
208                SingleOrBatchRpcRequest::Batch(rpc_reqs) => {
209                    let (mids, rpc_res): (Vec<_>, Vec<_>) =
210                        join_all(rpc_reqs.into_iter().map(|rpc_req| async {
211                            let mut mid = mid.clone();
212                            let res = self.process_rpc_request(&ctx, &mut mid, rpc_req).await;
213                            (mid, res)
214                        }))
215                        .await
216                        .into_iter()
217                        .unzip();
218                    let mut res = json_http_response(&rpc_res);
219                    for mut mid in mids.into_iter().fold(
220                        vec![Vec::with_capacity(rpc_res.len()); mid.len()],
221                        |mut acc, x| {
222                            for (idx, middleware) in x.into_iter().enumerate() {
223                                acc[idx].push(middleware);
224                            }
225                            acc
226                        },
227                    ) {
228                        for middleware in mid.iter_mut() {
229                            middleware.0.process_http_response(&ctx, &mut res).await;
230                        }
231                    }
232                    Ok(res)
233                }
234            }
235        }
236        .await
237        {
238            Ok(a) => a,
239            Err(e) => json_http_response(&RpcResponse {
240                id: None,
241                result: Err(e),
242            }),
243        }
244    }
245    async fn process_rpc_request(
246        &self,
247        ctx: &Context,
248        mid: &mut Vector<DynMiddleware<Context>>,
249        mut req: RpcRequest,
250    ) -> RpcResponse {
251        let metadata = Value::Object(
252            self.inner
253                .root_handler
254                .metadata(
255                    match self
256                        .inner
257                        .root_handler
258                        .method_from_dots(req.method.as_str())
259                    {
260                        Some(a) => a,
261                        None => {
262                            return RpcResponse {
263                                id: req.id,
264                                result: Err(yajrc::METHOD_NOT_FOUND_ERROR),
265                            }
266                        }
267                    },
268                )
269                .into_iter()
270                .map(|(key, value)| (key.into(), value))
271                .collect(),
272        );
273        let mut res = async {
274            for middleware in mid.iter_mut().rev() {
275                if let Err(res) = middleware
276                    .0
277                    .process_rpc_request(ctx, metadata.clone(), &mut req)
278                    .await
279                {
280                    return res;
281                }
282            }
283            self.inner.handle_single_request(req).await
284        }
285        .await;
286        for middleware in mid.iter_mut() {
287            middleware.0.process_rpc_response(ctx, &mut res).await;
288        }
289        res
290    }
291    pub fn handle(&self, req: Request) -> BoxFuture<'static, Response> {
292        let server = self.clone();
293        async move { server.process_http_request(req).await }.boxed()
294    }
295}
296
297impl<Context: crate::Context> Handler<(), ()> for HttpServer<Context> {
298    type Future = BoxFuture<'static, Response>;
299    fn call(self, req: Request, _: ()) -> Self::Future {
300        self.handle(req)
301    }
302}