1use axum::body::Body;
2use axum::extract::Request;
3use axum::handler::Handler;
4use axum::response::Response;
5use futures::future::{join_all, BoxFuture};
6use futures::{Future, FutureExt};
7use http::header::{CONTENT_LENGTH, CONTENT_TYPE};
8use http_body_util::BodyExt;
9use imbl_value::imbl::Vector;
10use imbl_value::Value;
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use yajrc::{RpcError, RpcMethod};
14
15use crate::server::{RpcRequest, RpcResponse, SingleOrBatchRpcRequest};
16use crate::util::{internal_error, parse_error};
17use crate::{HandleAny, Server};
18
/// Canned JSON-RPC internal-error body (code -32603), used when a real response cannot be
/// serialized.
// NOTE(review): this body has no `"jsonrpc"` or `"id"` member, which a JSON-RPC 2.0 response
// object normally carries; confirm clients tolerate that.
const FALLBACK_ERROR: &str = r#"{"error":{"code":-32603,"message":"Internal error","data":"Failed to serialize rpc response"}}"#;
20
21pub fn fallback_rpc_error_response() -> Response {
22 Response::builder()
23 .header(CONTENT_TYPE, "application/json")
24 .header(CONTENT_LENGTH, FALLBACK_ERROR.len())
25 .body(Body::from(FALLBACK_ERROR.as_bytes()))
26 .unwrap()
27}
28
29pub fn json_http_response<T: Serialize>(t: &T) -> Response {
30 let body = match serde_json::to_vec(t) {
31 Ok(a) => a,
32 Err(_) => return fallback_rpc_error_response(),
33 };
34 Response::builder()
35 .header(CONTENT_TYPE, "application/json")
36 .header(CONTENT_LENGTH, body.len())
37 .body(Body::from(body))
38 .unwrap_or_else(|_| fallback_rpc_error_response())
39}
40
41pub trait Middleware<Context: Send + 'static>: Clone + Send + Sync + 'static {
42 type Metadata: DeserializeOwned + Send + 'static;
43 #[allow(unused_variables)]
44 fn process_http_request(
45 &mut self,
46 context: &Context,
47 request: &mut Request,
48 ) -> impl Future<Output = Result<(), Response>> + Send {
49 async { Ok(()) }
50 }
51 #[allow(unused_variables)]
52 fn process_rpc_request(
53 &mut self,
54 context: &Context,
55 metadata: Self::Metadata,
56 request: &mut RpcRequest,
57 ) -> impl Future<Output = Result<(), RpcResponse>> + Send {
58 async { Ok(()) }
59 }
60 #[allow(unused_variables)]
61 fn process_rpc_response(
62 &mut self,
63 context: &Context,
64 response: &mut RpcResponse,
65 ) -> impl Future<Output = ()> + Send {
66 async { () }
67 }
68 #[allow(unused_variables)]
69 fn process_http_response(
70 &mut self,
71 context: &Context,
72 response: &mut Response,
73 ) -> impl Future<Output = ()> + Send {
74 async { () }
75 }
76}
77
78#[allow(private_bounds)]
79trait _Middleware<Context>: Send + Sync {
80 fn dyn_clone(&self) -> DynMiddleware<Context>;
81 fn process_http_request<'a>(
82 &'a mut self,
83 context: &'a Context,
84 request: &'a mut Request,
85 ) -> BoxFuture<'a, Result<(), Response>>;
86 fn process_rpc_request<'a>(
87 &'a mut self,
88 context: &'a Context,
89 metadata: Value,
90 request: &'a mut RpcRequest,
91 ) -> BoxFuture<'a, Result<(), RpcResponse>>;
92 fn process_rpc_response<'a>(
93 &'a mut self,
94
95 context: &'a Context,
96 response: &'a mut RpcResponse,
97 ) -> BoxFuture<'a, ()>;
98 fn process_http_response<'a>(
99 &'a mut self,
100 context: &'a Context,
101 response: &'a mut Response,
102 ) -> BoxFuture<'a, ()>;
103}
104impl<Context: Send + 'static, T: Middleware<Context> + Send + Sync> _Middleware<Context> for T {
105 fn dyn_clone(&self) -> DynMiddleware<Context> {
106 DynMiddleware(Box::new(<Self as Clone>::clone(&self)))
107 }
108 fn process_http_request<'a>(
109 &'a mut self,
110 context: &'a Context,
111 request: &'a mut Request,
112 ) -> BoxFuture<'a, Result<(), Response>> {
113 <Self as Middleware<Context>>::process_http_request(self, context, request).boxed()
114 }
115 fn process_rpc_request<'a>(
116 &'a mut self,
117 context: &'a Context,
118 metadata: Value,
119 request: &'a mut RpcRequest,
120 ) -> BoxFuture<'a, Result<(), RpcResponse>> {
121 <Self as Middleware<Context>>::process_rpc_request(
122 self,
123 context,
124 match imbl_value::from_value(metadata) {
125 Ok(a) => a,
126 Err(e) => return async { Err(internal_error(e).into()) }.boxed(),
127 },
128 request,
129 )
130 .boxed()
131 }
132 fn process_rpc_response<'a>(
133 &'a mut self,
134 context: &'a Context,
135 response: &'a mut RpcResponse,
136 ) -> BoxFuture<'a, ()> {
137 <Self as Middleware<Context>>::process_rpc_response(self, context, response).boxed()
138 }
139 fn process_http_response<'a>(
140 &'a mut self,
141 context: &'a Context,
142 response: &'a mut Response,
143 ) -> BoxFuture<'a, ()> {
144 <Self as Middleware<Context>>::process_http_response(self, context, response).boxed()
145 }
146}
147
148struct DynMiddleware<Context>(Box<dyn _Middleware<Context>>);
149impl<Context> Clone for DynMiddleware<Context> {
150 fn clone(&self) -> Self {
151 self.0.dyn_clone()
152 }
153}
154
155pub struct HttpServer<Context: crate::Context> {
156 inner: Server<Context>,
157 middleware: Vector<DynMiddleware<Context>>,
158}
159impl<Context: crate::Context> Clone for HttpServer<Context> {
160 fn clone(&self) -> Self {
161 Self {
162 inner: self.inner.clone(),
163 middleware: self.middleware.clone(),
164 }
165 }
166}
167impl<Context: crate::Context> Server<Context> {
168 pub fn for_http(self) -> HttpServer<Context> {
169 HttpServer {
170 inner: self,
171 middleware: Vector::new(),
172 }
173 }
174 pub fn middleware<T: Middleware<Context>>(self, middleware: T) -> HttpServer<Context> {
175 self.for_http().middleware(middleware)
176 }
177}
178impl<Context: crate::Context> HttpServer<Context> {
179 pub fn middleware<T: Middleware<Context>>(mut self, middleware: T) -> Self {
180 self.middleware
181 .push_back(DynMiddleware(Box::new(middleware)));
182 self
183 }
184 async fn process_http_request(&self, mut req: Request) -> Response {
185 let mut mid = self.middleware.clone();
186 match async {
187 let ctx = (self.inner.make_ctx)().await?;
188 for middleware in mid.iter_mut().rev() {
189 if let Err(e) = middleware.0.process_http_request(&ctx, &mut req).await {
190 return Ok::<_, RpcError>(e);
191 }
192 }
193 let (_, body) = req.into_parts();
194 match serde_json::from_slice::<SingleOrBatchRpcRequest>(
195 &*body.collect().await.map_err(internal_error)?.to_bytes(),
196 )
197 .map_err(parse_error)?
198 {
199 SingleOrBatchRpcRequest::Single(rpc_req) => {
200 let mut res = json_http_response(
201 &self.process_rpc_request(&ctx, &mut mid, rpc_req).await,
202 );
203 for middleware in mid.iter_mut() {
204 middleware.0.process_http_response(&ctx, &mut res).await;
205 }
206 Ok(res)
207 }
208 SingleOrBatchRpcRequest::Batch(rpc_reqs) => {
209 let (mids, rpc_res): (Vec<_>, Vec<_>) =
210 join_all(rpc_reqs.into_iter().map(|rpc_req| async {
211 let mut mid = mid.clone();
212 let res = self.process_rpc_request(&ctx, &mut mid, rpc_req).await;
213 (mid, res)
214 }))
215 .await
216 .into_iter()
217 .unzip();
218 let mut res = json_http_response(&rpc_res);
219 for mut mid in mids.into_iter().fold(
220 vec![Vec::with_capacity(rpc_res.len()); mid.len()],
221 |mut acc, x| {
222 for (idx, middleware) in x.into_iter().enumerate() {
223 acc[idx].push(middleware);
224 }
225 acc
226 },
227 ) {
228 for middleware in mid.iter_mut() {
229 middleware.0.process_http_response(&ctx, &mut res).await;
230 }
231 }
232 Ok(res)
233 }
234 }
235 }
236 .await
237 {
238 Ok(a) => a,
239 Err(e) => json_http_response(&RpcResponse {
240 id: None,
241 result: Err(e),
242 }),
243 }
244 }
245 async fn process_rpc_request(
246 &self,
247 ctx: &Context,
248 mid: &mut Vector<DynMiddleware<Context>>,
249 mut req: RpcRequest,
250 ) -> RpcResponse {
251 let metadata = Value::Object(
252 self.inner
253 .root_handler
254 .metadata(
255 match self
256 .inner
257 .root_handler
258 .method_from_dots(req.method.as_str())
259 {
260 Some(a) => a,
261 None => {
262 return RpcResponse {
263 id: req.id,
264 result: Err(yajrc::METHOD_NOT_FOUND_ERROR),
265 }
266 }
267 },
268 )
269 .into_iter()
270 .map(|(key, value)| (key.into(), value))
271 .collect(),
272 );
273 let mut res = async {
274 for middleware in mid.iter_mut().rev() {
275 if let Err(res) = middleware
276 .0
277 .process_rpc_request(ctx, metadata.clone(), &mut req)
278 .await
279 {
280 return res;
281 }
282 }
283 self.inner.handle_single_request(req).await
284 }
285 .await;
286 for middleware in mid.iter_mut() {
287 middleware.0.process_rpc_response(ctx, &mut res).await;
288 }
289 res
290 }
291 pub fn handle(&self, req: Request) -> BoxFuture<'static, Response> {
292 let server = self.clone();
293 async move { server.process_http_request(req).await }.boxed()
294 }
295}
296
297impl<Context: crate::Context> Handler<(), ()> for HttpServer<Context> {
298 type Future = BoxFuture<'static, Response>;
299 fn call(self, req: Request, _: ()) -> Self::Future {
300 self.handle(req)
301 }
302}