// rspc_legacy/middleware.rs

1use futures::StreamExt;
2use serde_json::Value;
3use std::{future::Future, marker::PhantomData, sync::Arc};
4
5use crate::{
6    internal::{Layer, LayerResult, RequestContext, ValueOrStream, ValueOrStreamOrFutureStream},
7    ExecError,
8};
9
10pub trait MiddlewareLike<TLayerCtx>: Clone {
11    type State: Clone + Send + Sync + 'static;
12    type NewCtx: Send + 'static;
13
14    fn handle<TMiddleware: Layer<Self::NewCtx> + 'static>(
15        &self,
16        ctx: TLayerCtx,
17        input: Value,
18        req: RequestContext,
19        next: Arc<TMiddleware>,
20    ) -> Result<LayerResult, ExecError>;
21}
22pub struct MiddlewareContext<TLayerCtx, TNewCtx = TLayerCtx, TState = ()>
23where
24    TState: Send,
25{
26    pub state: TState,
27    pub input: Value,
28    pub ctx: TNewCtx,
29    pub req: RequestContext,
30    pub phantom: PhantomData<TLayerCtx>,
31}
32
33// This will match were TState is the default (`()`) so it shouldn't let you call it if you've already swapped the generic
34impl<TLayerCtx, TNewCtx> MiddlewareContext<TLayerCtx, TNewCtx, ()>
35where
36    TLayerCtx: Send,
37{
38    pub fn with_state<TState>(self, state: TState) -> MiddlewareContext<TLayerCtx, TNewCtx, TState>
39    where
40        TState: Send,
41    {
42        MiddlewareContext {
43            state,
44            input: self.input,
45            ctx: self.ctx,
46            req: self.req,
47            phantom: PhantomData,
48        }
49    }
50}
51
52// This will match were TNewCtx is the default (`TCtx`) so it shouldn't let you call it if you've already swapped the generic
53impl<TLayerCtx, TState> MiddlewareContext<TLayerCtx, TLayerCtx, TState>
54where
55    TLayerCtx: Send,
56    TState: Send,
57{
58    pub fn with_ctx<TNewCtx>(
59        self,
60        new_ctx: TNewCtx,
61    ) -> MiddlewareContext<TLayerCtx, TNewCtx, TState> {
62        MiddlewareContext {
63            state: self.state,
64            input: self.input,
65            ctx: new_ctx,
66            req: self.req,
67            phantom: PhantomData,
68        }
69    }
70}
71
/// A middleware built from a single handler that runs before the next layer.
///
/// `THandlerFunc` is called with a [`MiddlewareContext`] whose context is the incoming
/// `TLayerCtx` and whose state is `()`. The future it returns resolves to a context that may
/// carry a new context type (`TNewCtx`) and a state (`TState`). Use [`Middleware::resp`] to
/// also attach a response handler.
pub struct Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
where
    TState: Send,
    TLayerCtx: Send,
    THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
    // `TNewCtx` and `THandlerFut` are not stored in any field; they are tied to
    // `THandlerFunc` only through these bounds.
    // NOTE(review): moving these bounds off the struct would likely leave those parameters
    // unused (E0392).
    THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
        + Send
        + 'static,
{
    /// The user-supplied handler, called once per request.
    handler: THandlerFunc,
    /// Marker for type parameters that are not stored in a field.
    phantom: PhantomData<(TState, TLayerCtx)>,
}
84
85impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut> Clone
86    for Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
87where
88    TState: Send,
89    TLayerCtx: Send,
90    THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
91    THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
92        + Send
93        + 'static,
94{
95    fn clone(&self) -> Self {
96        Self {
97            handler: self.handler.clone(),
98            phantom: PhantomData,
99        }
100    }
101}
102
/// Entry point for building [`Middleware`] for a layer whose context type is `TLayerCtx`.
///
/// The struct carries no trait bounds (Rust API Guidelines, C-STRUCT-BOUNDS). The `Send`
/// requirement is declared on the `impl` that provides `middleware`.
pub struct MiddlewareBuilder<TLayerCtx>(pub PhantomData<TLayerCtx>);
106
107impl<TLayerCtx> MiddlewareBuilder<TLayerCtx>
108where
109    TLayerCtx: Send,
110{
111    pub fn middleware<TState, TNewCtx, THandlerFunc, THandlerFut>(
112        &self,
113        handler: THandlerFunc,
114    ) -> Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
115    where
116        TState: Send,
117        THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
118        THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
119            + Send
120            + 'static,
121    {
122        Middleware {
123            handler,
124            phantom: PhantomData,
125        }
126    }
127}
128
129impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
130    Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
131where
132    TState: Send,
133    TLayerCtx: Send,
134    THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
135    THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
136        + Send
137        + 'static,
138{
139    pub fn resp<TRespHandlerFunc, TRespHandlerFut>(
140        self,
141        handler: TRespHandlerFunc,
142    ) -> MiddlewareWithResponseHandler<
143        TState,
144        TLayerCtx,
145        TNewCtx,
146        THandlerFunc,
147        THandlerFut,
148        TRespHandlerFunc,
149        TRespHandlerFut,
150    >
151    where
152        TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
153        TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
154    {
155        MiddlewareWithResponseHandler {
156            handler: self.handler,
157            resp_handler: handler,
158            phantom: PhantomData,
159        }
160    }
161}
162
/// A [`Middleware`] that also transforms the next layer's response.
///
/// After the next layer produces its result, `TRespHandlerFunc` is called with the state
/// returned by the request handler and the result value. For a streamed result, it is called
/// once per `Ok` item. Built with [`Middleware::resp`].
pub struct MiddlewareWithResponseHandler<
    TState,
    TLayerCtx,
    TNewCtx,
    THandlerFunc,
    THandlerFut,
    TRespHandlerFunc,
    TRespHandlerFut,
> where
    TState: Send,
    TLayerCtx: Send,
    THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
    // `TNewCtx`, `THandlerFut` and `TRespHandlerFut` are not stored in any field; they are
    // tied to the handler types only through these bounds.
    // NOTE(review): moving the bounds off the struct would likely trigger E0392.
    THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
        + Send
        + 'static,
    TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
    TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
{
    /// Handler run before the next layer.
    handler: THandlerFunc,
    /// Handler run on the next layer's result value(s).
    resp_handler: TRespHandlerFunc,
    /// Marker for type parameters that are not stored in a field.
    phantom: PhantomData<(TState, TLayerCtx)>,
}
185
186impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut, TRespHandlerFunc, TRespHandlerFut> Clone
187    for MiddlewareWithResponseHandler<
188        TState,
189        TLayerCtx,
190        TNewCtx,
191        THandlerFunc,
192        THandlerFut,
193        TRespHandlerFunc,
194        TRespHandlerFut,
195    >
196where
197    TState: Send,
198    TLayerCtx: Send,
199    THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
200    THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
201        + Send
202        + 'static,
203    TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
204    TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
205{
206    fn clone(&self) -> Self {
207        Self {
208            handler: self.handler.clone(),
209            resp_handler: self.resp_handler.clone(),
210            phantom: PhantomData,
211        }
212    }
213}
214
215impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut> MiddlewareLike<TLayerCtx>
216    for Middleware<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut>
217where
218    TState: Clone + Send + Sync + 'static,
219    TLayerCtx: Send,
220    TNewCtx: Send + 'static,
221    THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
222    THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
223        + Send
224        + 'static,
225{
226    type State = TState;
227    type NewCtx = TNewCtx;
228
229    fn handle<TMiddleware: Layer<Self::NewCtx> + 'static>(
230        &self,
231        ctx: TLayerCtx,
232        input: Value,
233        req: RequestContext,
234        next: Arc<TMiddleware>,
235    ) -> Result<LayerResult, ExecError> {
236        let handler = (self.handler)(MiddlewareContext {
237            state: (),
238            ctx,
239            input,
240            req,
241            phantom: PhantomData,
242        });
243
244        Ok(LayerResult::FutureValueOrStream(Box::pin(async move {
245            let handler = handler.await?;
246            next.call(handler.ctx, handler.input, handler.req)?
247                .into_value_or_stream()
248                .await
249        })))
250    }
251}
252
253enum FutOrValue<T: Future<Output = Result<Value, crate::Error>>> {
254    Fut(T),
255    Value(Result<Value, ExecError>),
256}
257
258impl<TState, TLayerCtx, TNewCtx, THandlerFunc, THandlerFut, TRespHandlerFunc, TRespHandlerFut>
259    MiddlewareLike<TLayerCtx>
260    for MiddlewareWithResponseHandler<
261        TState,
262        TLayerCtx,
263        TNewCtx,
264        THandlerFunc,
265        THandlerFut,
266        TRespHandlerFunc,
267        TRespHandlerFut,
268    >
269where
270    TState: Clone + Send + Sync + 'static,
271    TLayerCtx: Send + 'static,
272    TNewCtx: Send + 'static,
273    THandlerFunc: Fn(MiddlewareContext<TLayerCtx, TLayerCtx, ()>) -> THandlerFut + Clone,
274    THandlerFut: Future<Output = Result<MiddlewareContext<TLayerCtx, TNewCtx, TState>, crate::Error>>
275        + Send
276        + 'static,
277    TRespHandlerFunc: Fn(TState, Value) -> TRespHandlerFut + Clone + Sync + Send + 'static,
278    TRespHandlerFut: Future<Output = Result<Value, crate::Error>> + Send + 'static,
279{
280    type State = TState;
281    type NewCtx = TNewCtx;
282
283    fn handle<TMiddleware: Layer<Self::NewCtx> + 'static>(
284        &self,
285        ctx: TLayerCtx,
286        input: Value,
287        req: RequestContext,
288        next: Arc<TMiddleware>,
289    ) -> Result<LayerResult, ExecError> {
290        let handler = (self.handler)(MiddlewareContext {
291            state: (),
292            ctx,
293            input,
294            req,
295            // new_ctx: None,
296            phantom: PhantomData,
297        });
298
299        let f = self.resp_handler.clone(); // TODO: Runtime clone is bad. Avoid this!
300
301        Ok(LayerResult::FutureValueOrStreamOrFutureStream(Box::pin(
302            async move {
303                let handler = handler.await?;
304
305                Ok(
306                    match next
307                        .call(handler.ctx, handler.input, handler.req)?
308                        .into_value_or_stream()
309                        .await?
310                    {
311                        ValueOrStream::Value(v) => {
312                            ValueOrStreamOrFutureStream::Value(f(handler.state, v).await?)
313                        }
314                        ValueOrStream::Stream(s) => {
315                            ValueOrStreamOrFutureStream::Stream(Box::pin(s.then(move |v| {
316                                let v = match v {
317                                    Ok(v) => FutOrValue::Fut(f(handler.state.clone(), v)),
318                                    e => FutOrValue::Value(e),
319                                };
320
321                                async move {
322                                    match v {
323                                        FutOrValue::Fut(fut) => {
324                                            fut.await.map_err(ExecError::ErrResolverError)
325                                        }
326                                        FutOrValue::Value(v) => v,
327                                    }
328                                }
329                            })))
330                        }
331                    },
332                )
333            },
334        )))
335    }
336}
337
338// TODO: Middleware functions should be able to be async or sync & return a value or result