rspc_legacy/internal/
middleware.rs

1use std::{fmt, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
2
3use futures::Stream;
4use serde_json::Value;
5use specta::Type;
6
7use crate::{ExecError, MiddlewareLike};
8
9pub trait MiddlewareBuilderLike<TCtx> {
10    type LayerContext: 'static;
11
12    fn build<T>(&self, next: T) -> Box<dyn Layer<TCtx>>
13    where
14        T: Layer<Self::LayerContext>;
15}
16
17pub struct MiddlewareMerger<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TIncomingMiddleware>
18where
19    TMiddleware: MiddlewareBuilderLike<TCtx, LayerContext = TLayerCtx>,
20    TIncomingMiddleware: MiddlewareBuilderLike<TLayerCtx, LayerContext = TNewLayerCtx>,
21{
22    pub middleware: TMiddleware,
23    pub middleware2: TIncomingMiddleware,
24    pub phantom: PhantomData<(TCtx, TLayerCtx)>,
25}
26
27impl<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TIncomingMiddleware> MiddlewareBuilderLike<TCtx>
28    for MiddlewareMerger<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TIncomingMiddleware>
29where
30    TCtx: 'static,
31    TLayerCtx: 'static,
32    TNewLayerCtx: 'static,
33    TMiddleware: MiddlewareBuilderLike<TCtx, LayerContext = TLayerCtx>,
34    TIncomingMiddleware: MiddlewareBuilderLike<TLayerCtx, LayerContext = TNewLayerCtx>,
35{
36    type LayerContext = TNewLayerCtx;
37
38    fn build<T>(&self, next: T) -> Box<dyn Layer<TCtx>>
39    where
40        T: Layer<Self::LayerContext>,
41    {
42        self.middleware.build(self.middleware2.build(next))
43    }
44}
45
46pub struct MiddlewareLayerBuilder<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware>
47where
48    TCtx: Send + Sync + 'static,
49    TLayerCtx: Send + Sync + 'static,
50    TNewLayerCtx: Send + Sync + 'static,
51    TMiddleware: MiddlewareBuilderLike<TCtx, LayerContext = TLayerCtx> + Send + 'static,
52    TNewMiddleware: MiddlewareLike<TLayerCtx, NewCtx = TNewLayerCtx>,
53{
54    pub middleware: TMiddleware,
55    pub mw: TNewMiddleware,
56    pub phantom: PhantomData<(TCtx, TLayerCtx, TNewLayerCtx)>,
57}
58
59impl<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware> MiddlewareBuilderLike<TCtx>
60    for MiddlewareLayerBuilder<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware>
61where
62    TCtx: Send + Sync + 'static,
63    TLayerCtx: Send + Sync + 'static,
64    TNewLayerCtx: Send + Sync + 'static,
65    TMiddleware: MiddlewareBuilderLike<TCtx, LayerContext = TLayerCtx> + Send + 'static,
66    TNewMiddleware: MiddlewareLike<TLayerCtx, NewCtx = TNewLayerCtx> + Send + Sync + 'static,
67{
68    type LayerContext = TNewLayerCtx;
69
70    fn build<T>(&self, next: T) -> Box<dyn Layer<TCtx>>
71    where
72        T: Layer<Self::LayerContext> + Sync,
73    {
74        self.middleware.build(MiddlewareLayer {
75            next: Arc::new(next),
76            mw: self.mw.clone(),
77            phantom: PhantomData,
78        })
79    }
80}
81
82pub struct MiddlewareLayer<TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware>
83where
84    TLayerCtx: Send + 'static,
85    TNewLayerCtx: Send + 'static,
86    TMiddleware: Layer<TNewLayerCtx> + 'static,
87    TNewMiddleware: MiddlewareLike<TLayerCtx, NewCtx = TNewLayerCtx> + Send + Sync + 'static,
88{
89    next: Arc<TMiddleware>, // TODO: Avoid arcing this if possible
90    mw: TNewMiddleware,
91    phantom: PhantomData<(TLayerCtx, TNewLayerCtx)>,
92}
93
94impl<TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware> Layer<TLayerCtx>
95    for MiddlewareLayer<TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware>
96where
97    TLayerCtx: Send + Sync + 'static,
98    TNewLayerCtx: Send + Sync + 'static,
99    TMiddleware: Layer<TNewLayerCtx> + Sync + 'static,
100    TNewMiddleware: MiddlewareLike<TLayerCtx, NewCtx = TNewLayerCtx> + Send + Sync + 'static,
101{
102    fn call(
103        &self,
104        ctx: TLayerCtx,
105        input: Value,
106        req: RequestContext,
107    ) -> Result<LayerResult, ExecError> {
108        self.mw.handle(ctx, input, req, self.next.clone())
109    }
110}
111
/// Middleware builder that adds no behavior of its own. Its `build` simply
/// boxes `next`, keeping the context type `TCtx` unchanged.
pub struct BaseMiddleware<TCtx: 'static>(PhantomData<TCtx>);

impl<TCtx: 'static> Default for BaseMiddleware<TCtx> {
    /// Creates the marker value. This is written by hand because
    /// `#[derive(Default)]` would add an unnecessary `TCtx: Default` bound.
    fn default() -> Self {
        BaseMiddleware(PhantomData)
    }
}
124
125impl<TCtx> MiddlewareBuilderLike<TCtx> for BaseMiddleware<TCtx>
126where
127    TCtx: Send + 'static,
128{
129    type LayerContext = TCtx;
130
131    fn build<T>(&self, next: T) -> Box<dyn Layer<TCtx>>
132    where
133        T: Layer<Self::LayerContext>,
134    {
135        Box::new(next)
136    }
137}
138
139// TODO: Rename this so it doesn't conflict with the middleware builder struct
140pub trait Layer<TLayerCtx: 'static>: Send + Sync + 'static {
141    fn call(&self, a: TLayerCtx, b: Value, c: RequestContext) -> Result<LayerResult, ExecError>;
142}
143
144pub struct ResolverLayer<TLayerCtx, T>
145where
146    TLayerCtx: Send + Sync + 'static,
147    T: Fn(TLayerCtx, Value, RequestContext) -> Result<LayerResult, ExecError>
148        + Send
149        + Sync
150        + 'static,
151{
152    pub func: T,
153    pub phantom: PhantomData<TLayerCtx>,
154}
155
156impl<T, TLayerCtx> Layer<TLayerCtx> for ResolverLayer<TLayerCtx, T>
157where
158    TLayerCtx: Send + Sync + 'static,
159    T: Fn(TLayerCtx, Value, RequestContext) -> Result<LayerResult, ExecError>
160        + Send
161        + Sync
162        + 'static,
163{
164    fn call(&self, a: TLayerCtx, b: Value, c: RequestContext) -> Result<LayerResult, ExecError> {
165        (self.func)(a, b, c)
166    }
167}
168
169impl<TLayerCtx> Layer<TLayerCtx> for Box<dyn Layer<TLayerCtx> + 'static>
170where
171    TLayerCtx: 'static,
172{
173    fn call(&self, a: TLayerCtx, b: Value, c: RequestContext) -> Result<LayerResult, ExecError> {
174        (**self).call(a, b, c)
175    }
176}
177
// TODO: This is a clone of `rspc::ProcedureKind`. I don't like us having both but we need it for the dependency tree to work.
/// The kind of procedure a request targets.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by
/// declaration order. The `#[specta(rename_all = "camelCase")]` attribute
/// renames the variants in the type exported via `specta::Type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Type)]
#[specta(rename_all = "camelCase")]
pub enum ProcedureKind {
    Query,
    Mutation,
    Subscription,
}
186
187impl fmt::Display for ProcedureKind {
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        match self {
190            Self::Query => write!(f, "Query"),
191            Self::Mutation => write!(f, "Mutation"),
192            Self::Subscription => write!(f, "Subscription"),
193        }
194    }
195}
196
// TODO: Maybe rename to `Request` or something else. Also move into the public API, because it might be used in middleware
/// Metadata about the procedure being invoked. It is passed to every
/// [`Layer::call`].
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Which kind of procedure is being called.
    pub kind: ProcedureKind,
    /// The procedure's path; presumably the key it was registered under — confirm.
    pub path: String, // TODO: String slice??
}
203
/// Output of [`LayerResult::into_value_or_stream`]: either one value or a
/// stream of values.
pub enum ValueOrStream {
    /// A single, already-available value.
    Value(Value),
    /// A boxed, pinned stream whose items are each individually fallible.
    Stream(Pin<Box<dyn Stream<Item = Result<Value, ExecError>> + Send>>),
}
208
/// Output of the future held by `LayerResult::FutureValueOrStreamOrFutureStream`.
///
/// NOTE(review): despite its name, this currently has exactly the same variants
/// as `ValueOrStream` (there is no "future stream" variant). Confirm whether a
/// variant is missing or whether the two types can be merged.
pub enum ValueOrStreamOrFutureStream {
    /// A single, already-available value.
    Value(Value),
    /// A boxed, pinned stream whose items are each individually fallible.
    Stream(Pin<Box<dyn Stream<Item = Result<Value, ExecError>> + Send>>),
}
213
/// What a [`Layer`] produces when called.
///
/// Every variant can be normalized with `LayerResult::into_value_or_stream`.
pub enum LayerResult {
    /// A future resolving to a single value.
    Future(Pin<Box<dyn Future<Output = Result<Value, ExecError>> + Send>>),
    /// A stream of values, each of which may be an error.
    Stream(Pin<Box<dyn Stream<Item = Result<Value, ExecError>> + Send>>),
    /// A future that resolves to either a single value or a stream.
    FutureValueOrStream(Pin<Box<dyn Future<Output = Result<ValueOrStream, ExecError>> + Send>>),
    /// A future resolving to a `ValueOrStreamOrFutureStream`.
    FutureValueOrStreamOrFutureStream(
        Pin<Box<dyn Future<Output = Result<ValueOrStreamOrFutureStream, ExecError>> + Send>>,
    ),
    /// An already-computed result that needs no awaiting.
    Ready(Result<Value, ExecError>),
}
223
224impl LayerResult {
225    pub async fn into_value_or_stream(self) -> Result<ValueOrStream, ExecError> {
226        match self {
227            LayerResult::Stream(stream) => Ok(ValueOrStream::Stream(stream)),
228            LayerResult::Future(fut) => Ok(ValueOrStream::Value(fut.await?)),
229            LayerResult::FutureValueOrStream(fut) => Ok(fut.await?),
230            LayerResult::FutureValueOrStreamOrFutureStream(fut) => Ok(match fut.await? {
231                ValueOrStreamOrFutureStream::Value(val) => ValueOrStream::Value(val),
232                ValueOrStreamOrFutureStream::Stream(stream) => ValueOrStream::Stream(stream),
233            }),
234            LayerResult::Ready(res) => Ok(ValueOrStream::Value(res?)),
235        }
236    }
237}