rspc_legacy/
router_builder.rs

1use std::marker::PhantomData;
2
3use futures::Stream;
4use serde::{de::DeserializeOwned, Serialize};
5use specta::{Type, TypeCollection};
6
7use crate::{
8    internal::{
9        BaseMiddleware, BuiltProcedureBuilder, MiddlewareBuilderLike, MiddlewareLayerBuilder,
10        MiddlewareMerger, ProcedureStore, ResolverLayer, UnbuiltProcedureBuilder,
11    },
12    Config, DoubleArgStreamMarker, ExecError, MiddlewareBuilder, MiddlewareLike, RequestLayer,
13    Resolver, Router, StreamResolver,
14};
15
16pub struct RouterBuilder<
17    TCtx = (), // The is the context the current router was initialised with
18    TMeta = (),
19    TMiddleware = BaseMiddleware<TCtx>,
20> where
21    TCtx: Send + Sync + 'static,
22    TMeta: Send + 'static,
23    TMiddleware: MiddlewareBuilderLike<TCtx> + Send + 'static,
24{
25    config: Config,
26    middleware: TMiddleware,
27    queries: ProcedureStore<TCtx>,
28    mutations: ProcedureStore<TCtx>,
29    subscriptions: ProcedureStore<TCtx>,
30    type_map: TypeCollection,
31    phantom: PhantomData<TMeta>,
32}
33
34#[allow(clippy::new_without_default, clippy::new_ret_no_self)]
35impl<TCtx, TMeta> Router<TCtx, TMeta>
36where
37    TCtx: Send + Sync + 'static,
38    TMeta: Send + 'static,
39{
40    pub fn new() -> RouterBuilder<TCtx, TMeta, BaseMiddleware<TCtx>> {
41        RouterBuilder::new()
42    }
43}
44
45#[allow(clippy::new_without_default)]
46impl<TCtx, TMeta> RouterBuilder<TCtx, TMeta, BaseMiddleware<TCtx>>
47where
48    TCtx: Send + Sync + 'static,
49    TMeta: Send + 'static,
50{
51    pub fn new() -> Self {
52        Self {
53            config: Config::new(),
54            middleware: BaseMiddleware::default(),
55            queries: ProcedureStore::new("query"),
56            mutations: ProcedureStore::new("mutation"),
57            subscriptions: ProcedureStore::new("subscription"),
58            type_map: Default::default(),
59            phantom: PhantomData,
60        }
61    }
62}
63
64impl<TCtx, TLayerCtx, TMeta, TMiddleware> RouterBuilder<TCtx, TMeta, TMiddleware>
65where
66    TCtx: Send + Sync + 'static,
67    TMeta: Send + 'static,
68    TLayerCtx: Send + Sync + 'static,
69    TMiddleware: MiddlewareBuilderLike<TCtx, LayerContext = TLayerCtx> + Send + 'static,
70{
71    /// Attach a configuration to the router. Calling this multiple times will overwrite the previous config.
72    pub fn config(mut self, config: Config) -> Self {
73        self.config = config;
74        self
75    }
76
77    pub fn middleware<TNewMiddleware, TNewLayerCtx>(
78        self,
79        builder: impl Fn(MiddlewareBuilder<TLayerCtx>) -> TNewMiddleware,
80    ) -> RouterBuilder<
81        TCtx,
82        TMeta,
83        MiddlewareLayerBuilder<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware>,
84    >
85    where
86        TNewLayerCtx: Send + Sync + 'static,
87        TNewMiddleware: MiddlewareLike<TLayerCtx, NewCtx = TNewLayerCtx> + Send + Sync + 'static,
88    {
89        let Self {
90            config,
91            middleware,
92            queries,
93            mutations,
94            subscriptions,
95            type_map,
96            ..
97        } = self;
98
99        let mw = builder(MiddlewareBuilder(PhantomData));
100        RouterBuilder {
101            config,
102            middleware: MiddlewareLayerBuilder {
103                middleware,
104                mw,
105                phantom: PhantomData,
106            },
107            queries,
108            mutations,
109            subscriptions,
110            type_map,
111            phantom: PhantomData,
112        }
113    }
114
115    pub fn query<TResolver, TArg, TResult, TResultMarker>(
116        mut self,
117        key: &'static str,
118        builder: impl Fn(
119            UnbuiltProcedureBuilder<TLayerCtx, TResolver>,
120        ) -> BuiltProcedureBuilder<TResolver>,
121    ) -> Self
122    where
123        TArg: DeserializeOwned + Type,
124        TResult: RequestLayer<TResultMarker>,
125        TResolver: Fn(TLayerCtx, TArg) -> TResult + Send + Sync + 'static,
126    {
127        let resolver = builder(UnbuiltProcedureBuilder::default()).resolver;
128        self.queries.append(
129            key.into(),
130            self.middleware.build(ResolverLayer {
131                func: move |ctx, input, _| {
132                    resolver.exec(
133                        ctx,
134                        serde_json::from_value(input).map_err(ExecError::DeserializingArgErr)?,
135                    )
136                },
137                phantom: PhantomData,
138            }),
139            TResolver::typedef(&mut self.type_map),
140        );
141        self
142    }
143
144    pub fn mutation<TResolver, TArg, TResult, TResultMarker>(
145        mut self,
146        key: &'static str,
147        builder: impl Fn(
148            UnbuiltProcedureBuilder<TLayerCtx, TResolver>,
149        ) -> BuiltProcedureBuilder<TResolver>,
150    ) -> Self
151    where
152        TArg: DeserializeOwned + Type,
153        TResult: RequestLayer<TResultMarker>,
154        TResolver: Fn(TLayerCtx, TArg) -> TResult + Send + Sync + 'static,
155    {
156        let resolver = builder(UnbuiltProcedureBuilder::default()).resolver;
157        self.mutations.append(
158            key.into(),
159            self.middleware.build(ResolverLayer {
160                func: move |ctx, input, _| {
161                    resolver.exec(
162                        ctx,
163                        serde_json::from_value(input).map_err(ExecError::DeserializingArgErr)?,
164                    )
165                },
166                phantom: PhantomData,
167            }),
168            TResolver::typedef(&mut self.type_map),
169        );
170        self
171    }
172
173    pub fn subscription<TResolver, TArg, TStream, TResult, TResultMarker>(
174        mut self,
175        key: &'static str,
176        builder: impl Fn(
177            UnbuiltProcedureBuilder<TLayerCtx, TResolver>,
178        ) -> BuiltProcedureBuilder<TResolver>,
179    ) -> Self
180    where
181        TArg: DeserializeOwned + Type,
182        TStream: Stream<Item = TResult> + Send + 'static,
183        TResult: Serialize + Type,
184        TResolver: Fn(TLayerCtx, TArg) -> TStream
185            + StreamResolver<TLayerCtx, DoubleArgStreamMarker<TArg, TResultMarker, TStream>>
186            + Send
187            + Sync
188            + 'static,
189    {
190        let resolver = builder(UnbuiltProcedureBuilder::default()).resolver;
191        self.subscriptions.append(
192            key.into(),
193            self.middleware.build(ResolverLayer {
194                func: move |ctx, input, _| {
195                    resolver.exec(
196                        ctx,
197                        serde_json::from_value(input).map_err(ExecError::DeserializingArgErr)?,
198                    )
199                },
200                phantom: PhantomData,
201            }),
202            TResolver::typedef(&mut self.type_map),
203        );
204        self
205    }
206
207    pub fn merge<TNewLayerCtx, TIncomingMiddleware>(
208        mut self,
209        prefix: &'static str,
210        router: RouterBuilder<TLayerCtx, TMeta, TIncomingMiddleware>,
211    ) -> Self
212    where
213        TNewLayerCtx: 'static,
214        TIncomingMiddleware:
215            MiddlewareBuilderLike<TLayerCtx, LayerContext = TNewLayerCtx> + Send + 'static,
216    {
217        #[allow(clippy::panic)]
218        if prefix.is_empty() || prefix.starts_with("rpc.") || prefix.starts_with("rspc.") {
219            panic!(
220                "rspc error: attempted to merge a router with the prefix '{}', however this name is not allowed.",
221                prefix
222            );
223        }
224
225        // TODO: The `data` field has gotta flow from the root router to the leaf routers so that we don't have to merge user defined types.
226
227        for (key, query) in router.queries.store {
228            // query.ty.key = format!("{}{}", prefix, key);
229            self.queries.append(
230                format!("{}{}", prefix, key),
231                self.middleware.build(query.exec),
232                query.ty,
233            );
234        }
235
236        for (key, mutation) in router.mutations.store {
237            // mutation.ty.key = format!("{}{}", prefix, key);
238            self.mutations.append(
239                format!("{}{}", prefix, key),
240                self.middleware.build(mutation.exec),
241                mutation.ty,
242            );
243        }
244
245        for (key, subscription) in router.subscriptions.store {
246            // subscription.ty.key = format!("{}{}", prefix, key);
247            self.subscriptions.append(
248                format!("{}{}", prefix, key),
249                self.middleware.build(subscription.exec),
250                subscription.ty,
251            );
252        }
253
254        self.type_map.extend(&router.type_map);
255
256        self
257    }
258
259    /// `legacy_merge` maintains the `merge` functionality prior to release 0.1.3
260    /// It will flow the `TMiddleware` and `TCtx` out of the child router to the parent router.
261    /// This was a confusing behavior and is generally not useful so it has been deprecated.
262    ///
263    /// This function will be remove in a future release. If you are using it open a GitHub issue to discuss your use case and longer term solutions for it.
264    pub fn legacy_merge<TNewLayerCtx, TIncomingMiddleware>(
265        self,
266        prefix: &'static str,
267        router: RouterBuilder<TLayerCtx, TMeta, TIncomingMiddleware>,
268    ) -> RouterBuilder<
269        TCtx,
270        TMeta,
271        MiddlewareMerger<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TIncomingMiddleware>,
272    >
273    where
274        TNewLayerCtx: 'static,
275        TIncomingMiddleware:
276            MiddlewareBuilderLike<TLayerCtx, LayerContext = TNewLayerCtx> + Send + 'static,
277    {
278        #[allow(clippy::panic)]
279        if prefix.is_empty() || prefix.starts_with("rpc.") || prefix.starts_with("rspc.") {
280            panic!(
281                "rspc error: attempted to merge a router with the prefix '{}', however this name is not allowed.",
282                prefix
283            );
284        }
285
286        let Self {
287            config,
288            middleware,
289            mut queries,
290            mut mutations,
291            mut subscriptions,
292            mut type_map,
293            ..
294        } = self;
295
296        for (key, query) in router.queries.store {
297            queries.append(
298                format!("{}{}", prefix, key),
299                middleware.build(query.exec),
300                query.ty,
301            );
302        }
303
304        for (key, mutation) in router.mutations.store {
305            mutations.append(
306                format!("{}{}", prefix, key),
307                middleware.build(mutation.exec),
308                mutation.ty,
309            );
310        }
311
312        for (key, subscription) in router.subscriptions.store {
313            subscriptions.append(
314                format!("{}{}", prefix, key),
315                middleware.build(subscription.exec),
316                subscription.ty,
317            );
318        }
319
320        type_map.extend(&router.type_map);
321
322        RouterBuilder {
323            config,
324            middleware: MiddlewareMerger {
325                middleware,
326                middleware2: router.middleware,
327                phantom: PhantomData,
328            },
329            queries,
330            mutations,
331            subscriptions,
332            type_map,
333            phantom: PhantomData,
334        }
335    }
336
337    pub fn build(self) -> Router<TCtx, TMeta> {
338        let Self {
339            config,
340            queries,
341            mutations,
342            subscriptions,
343            type_map,
344            ..
345        } = self;
346
347        let export_path = config.export_bindings_on_build.clone();
348        let router = Router {
349            config,
350            queries,
351            mutations,
352            subscriptions,
353            type_map,
354            phantom: PhantomData,
355        };
356
357        #[cfg(debug_assertions)]
358        #[allow(clippy::unwrap_used)]
359        if let Some(export_path) = export_path {
360            router.export_ts(export_path).unwrap();
361        }
362
363        router
364    }
365}