1use futures::StreamExt;
2use serde_json::Value;
3use std::{future::Future, marker::PhantomData, sync::Arc};
4
5use crate::{
6 internal::{Layer, LayerResult, RequestContext, ValueOrStream, ValueOrStreamOrFutureStream},
7 ExecError,
8};
9
10pub trait MiddlewareLike<TLayerCtx>: Clone {
11 type State: Clone + Send + Sync + 'static;
12 type NewCtx: Send + 'static;
13
14 fn handle<TMiddleware: Layer<Self::NewCtx> + 'static>(
15 &self,
16 ctx: TLayerCtx,
17 input: Value,
18 req: RequestContext,
19 next: Arc<TMiddleware>,
20 ) -> Result<LayerResult, ExecError>;
21}
22pub struct MiddlewareContext<TLayerCtx, TNewCtx = TLayerCtx, TState = ()>
23where
24 TState: Send,
25{
26 pub state: TState,
27 pub input: Value,
28 pub ctx: TNewCtx,
29 pub req: RequestContext,
30 pub phantom: PhantomData<TLayerCtx>,
31}
32
33impl<TLayerCtx, TNewCtx> MiddlewareContext<TLayerCtx, TNewCtx, ()>
35where
36 TLayerCtx: Send,
37{
38 pub fn with_state<TState>(self, state: TState) -> MiddlewareContext<TLayerCtx, TNewCtx, TState>
39 where
40 TState: Send,
41 {
42 MiddlewareContext {
43 state,
44 input: self.input,
45 ctx: self.ctx,
46 req: self.req,
47 phantom: PhantomData,
48 }
49 }
50}
51
52impl<TLayerCtx, TState> MiddlewareContext<TLayerCtx, TLayerCtx, TState>
54where
55 TLayerCtx: Send,
56 TState: Send,
57{
58 pub fn with_ctx<TNewCtx>(
59 self,
60 new_ctx: TNewCtx,
61 ) -> MiddlewareContext<TLayerCtx, TNewCtx, TState> {
62 MiddlewareContext {
63 state: self.state,
64 input: self.input,
65 ctx: new_ctx,
66 req: self.req,
67 phantom: PhantomData,
68 }
69 }
70}
71
72pub struct Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
73where
74 TState: Send,
75 TLayerCtx: Send,
76 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
77 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
78 + Send
79 + 'static,
80{
81 handler: THandlerFunc,
82 phantom: PhantomData<(TState, TLayerCtx)>,
83}
84
85impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut> Clone
86 for Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
87where
88 TState: Send,
89 TLayerCtx: Send,
90 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
91 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
92 + Send
93 + 'static,
94{
95 fn clone(&self) -> Self {
96 Self {
97 handler: self.handler.clone(),
98 phantom: PhantomData,
99 }
100 }
101}
102
/// Entry point for building a [`Middleware`] for a given incoming context type.
///
/// Construct it as `MiddlewareBuilder(PhantomData)` and call
/// [`MiddlewareBuilder::middleware`].
// No `TLayerCtx: Send` bound on the struct itself; the impl that needs it
// declares it, so naming the type does not require restating the bound.
pub struct MiddlewareBuilder<TLayerCtx>(pub PhantomData<TLayerCtx>);
106
107impl<TLayerCtx> MiddlewareBuilder<TLayerCtx>
108where
109 TLayerCtx: Send,
110{
111 pub fn middleware<TState, TNewCtx, THandlerFunc, THandlerFut>(
112 &self,
113 handler: THandlerFunc,
114 ) -> Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
115 where
116 TState: Send,
117 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
118 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
119 + Send
120 + 'static,
121 {
122 Middleware {
123 handler,
124 phantom: PhantomData,
125 }
126 }
127}
128
129impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
130 Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
131where
132 TState: Send,
133 TLayerCtx: Send,
134 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
135 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
136 + Send
137 + 'static,
138{
139 pub fn resp<TRespHandlerFunc, TRespHandlerFut>(
140 self,
141 handler: TRespHandlerFunc,
142 ) -> MiddlewareWithResponseHandler<
143 TState,
144 TLayerCtx,
145 TNewCtx,
146 THandlerFunc,
147 THandlerFut,
148 TRespHandlerFunc,
149 TRespHandlerFut,
150 >
151 where
152 TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
153 TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
154 {
155 MiddlewareWithResponseHandler {
156 handler: self.handler,
157 resp_handler: handler,
158 phantom: PhantomData,
159 }
160 }
161}
162
163pub struct MiddlewareWithResponseHandler<
164 TState,
165 TLayerCtx,
166 TNewCtx,
167 THandlerFunc,
168 THandlerFut,
169 TRespHandlerFunc,
170 TRespHandlerFut,
171> where
172 TState: Send,
173 TLayerCtx: Send,
174 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
175 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
176 + Send
177 + 'static,
178 TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
179 TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
180{
181 handler: THandlerFunc,
182 resp_handler: TRespHandlerFunc,
183 phantom: PhantomData<(TState, TLayerCtx)>,
184}
185
186impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut, TRespHandlerFunc, TRespHandlerFut> Clone
187 for MiddlewareWithResponseHandler<
188 TState,
189 TLayerCtx,
190 TNewCtx,
191 THandlerFunc,
192 THandlerFut,
193 TRespHandlerFunc,
194 TRespHandlerFut,
195 >
196where
197 TState: Send,
198 TLayerCtx: Send,
199 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
200 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
201 + Send
202 + 'static,
203 TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
204 TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
205{
206 fn clone(&self) -> Self {
207 Self {
208 handler: self.handler.clone(),
209 resp_handler: self.resp_handler.clone(),
210 phantom: PhantomData,
211 }
212 }
213}
214
215impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut> MiddlewareLike<TLayerCtx>
216 for Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
217where
218 TState: Clone + Send + Sync + 'static,
219 TLayerCtx: Send,
220 TNewCtx: Send + 'static,
221 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
222 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
223 + Send
224 + 'static,
225{
226 type State = TState;
227 type NewCtx = TNewCtx;
228
229 fn handle<TMiddleware: Layer<Self::NewCtx> + 'static>(
230 &self,
231 ctx: TLayerCtx,
232 input: Value,
233 req: RequestContext,
234 next: Arc<TMiddleware>,
235 ) -> Result<LayerResult, ExecError> {
236 let handler = (self.handler)(MiddlewareContext {
237 state: (),
238 ctx,
239 input,
240 req,
241 phantom: PhantomData,
242 });
243
244 Ok(LayerResult::FutureValueOrStream(Box::pin(async move {
245 let handler = handler.await?;
246 next.call(handler.ctx, handler.input, handler.req)?
247 .into_value_or_stream()
248 .await
249 })))
250 }
251}
252
253enum FutOrValue<T: Future<Output = Result<Value, crate::Error>>> {
254 Fut(T),
255 Value(Result<Value, ExecError>),
256}
257
258impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut, TRespHandlerFunc, TRespHandlerFut>
259 MiddlewareLike<TLayerCtx>
260 for MiddlewareWithResponseHandler<
261 TState,
262 TLayerCtx,
263 TNewCtx,
264 THandlerFunc,
265 THandlerFut,
266 TRespHandlerFunc,
267 TRespHandlerFut,
268 >
269where
270 TState: Clone + Send + Sync + 'static,
271 TLayerCtx: Send + 'static,
272 TNewCtx: Send + 'static,
273 THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
274 THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
275 + Send
276 + 'static,
277 TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
278 TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
279{
280 type State = TState;
281 type NewCtx = TNewCtx;
282
283 fn handle<TMiddleware: Layer<Self::NewCtx> + 'static>(
284 &self,
285 ctx: TLayerCtx,
286 input: Value,
287 req: RequestContext,
288 next: Arc<TMiddleware>,
289 ) -> Result<LayerResult, ExecError> {
290 let handler = (self.handler)(MiddlewareContext {
291 state: (),
292 ctx,
293 input,
294 req,
295 phantom: PhantomData,
297 });
298
299 let f = self.resp_handler.clone(); Ok(LayerResult::FutureValueOrStreamOrFutureStream(Box::pin(
302 async move {
303 let handler = handler.await?;
304
305 Ok(
306 match next
307 .call(handler.ctx, handler.input, handler.req)?
308 .into_value_or_stream()
309 .await?
310 {
311 ValueOrStream::Value(v) => {
312 ValueOrStreamOrFutureStream::Value(f(handler.state, v).await?)
313 }
314 ValueOrStream::Stream(s) => {
315 ValueOrStreamOrFutureStream::Stream(Box::pin(s.then(move |v| {
316 let v = match v {
317 Ok(v) => FutOrValue::Fut(f(handler.state.clone(), v)),
318 e => FutOrValue::Value(e),
319 };
320
321 async move {
322 match v {
323 FutOrValue::Fut(fut) => {
324 fut.await.map_err(ExecError::ErrResolverError)
325 }
326 FutOrValue::Value(v) => v,
327 }
328 }
329 })))
330 }
331 },
332 )
333 },
334 )))
335 }
336}
337
338