1use std::marker::PhantomData;
2
3use futures::Stream;
4use serde::{de::DeserializeOwned, Serialize};
5use specta::{Type, TypeCollection};
6
7use crate::{
8 internal::{
9 BaseMiddleware, BuiltProcedureBuilder, MiddlewareBuilderLike, MiddlewareLayerBuilder,
10 MiddlewareMerger, ProcedureStore, ResolverLayer, UnbuiltProcedureBuilder,
11 },
12 Config, DoubleArgStreamMarker, ExecError, MiddlewareBuilder, MiddlewareLike, RequestLayer,
13 Resolver, Router, StreamResolver,
14};
15
/// Builder that accumulates procedures, middleware and type definitions
/// before being turned into a [`Router`] by `build`.
///
/// * `TCtx` – the context type the procedure stores are keyed on.
/// * `TMeta` – carried only as a marker (see `phantom`).
/// * `TMiddleware` – the middleware chain; defaults to [`BaseMiddleware`].
pub struct RouterBuilder<
    TCtx = (), TMeta = (),
    TMiddleware = BaseMiddleware<TCtx>,
> where
    TCtx: Send + Sync + 'static,
    TMeta: Send + 'static,
    TMiddleware: MiddlewareBuilderLike<TCtx> + Send + 'static,
{
    /// Configuration handed over unchanged to the built router.
    config: Config,
    /// Middleware chain every registered procedure is built through.
    middleware: TMiddleware,
    /// Procedures registered as queries (store is created with the name "query").
    queries: ProcedureStore<TCtx>,
    /// Procedures registered as mutations (store is created with the name "mutation").
    mutations: ProcedureStore<TCtx>,
    /// Procedures registered as subscriptions (store is created with the name "subscription").
    subscriptions: ProcedureStore<TCtx>,
    /// Type definitions collected while procedures are registered or routers are merged.
    type_map: TypeCollection,
    /// Marker tying the otherwise unused `TMeta` parameter to the struct.
    phantom: PhantomData<TMeta>,
}
33
34#[allow(clippy::new_without_default, clippy::new_ret_no_self)]
35impl<TCtx, TMeta> Router<TCtx, TMeta>
36where
37 TCtx: Send + Sync + 'static,
38 TMeta: Send + 'static,
39{
40 pub fn new() -> RouterBuilder<TCtx, TMeta, BaseMiddleware<TCtx>> {
41 RouterBuilder::new()
42 }
43}
44
45#[allow(clippy::new_without_default)]
46impl<TCtx, TMeta> RouterBuilder<TCtx, TMeta, BaseMiddleware<TCtx>>
47where
48 TCtx: Send + Sync + 'static,
49 TMeta: Send + 'static,
50{
51 pub fn new() -> Self {
52 Self {
53 config: Config::new(),
54 middleware: BaseMiddleware::default(),
55 queries: ProcedureStore::new("query"),
56 mutations: ProcedureStore::new("mutation"),
57 subscriptions: ProcedureStore::new("subscription"),
58 type_map: Default::default(),
59 phantom: PhantomData,
60 }
61 }
62}
63
64impl<TCtx, TLayerCtx, TMeta, TMiddleware> RouterBuilder<TCtx, TMeta, TMiddleware>
65where
66 TCtx: Send + Sync + 'static,
67 TMeta: Send + 'static,
68 TLayerCtx: Send + Sync + 'static,
69 TMiddleware: MiddlewareBuilderLike<TCtx, LayerContext = TLayerCtx> + Send + 'static,
70{
71 pub fn config(mut self, config: Config) -> Self {
73 self.config = config;
74 self
75 }
76
77 pub fn middleware<TNewMiddleware, TNewLayerCtx>(
78 self,
79 builder: impl Fn(MiddlewareBuilder<TLayerCtx>) -> TNewMiddleware,
80 ) -> RouterBuilder<
81 TCtx,
82 TMeta,
83 MiddlewareLayerBuilder<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TNewMiddleware>,
84 >
85 where
86 TNewLayerCtx: Send + Sync + 'static,
87 TNewMiddleware: MiddlewareLike<TLayerCtx, NewCtx = TNewLayerCtx> + Send + Sync + 'static,
88 {
89 let Self {
90 config,
91 middleware,
92 queries,
93 mutations,
94 subscriptions,
95 type_map,
96 ..
97 } = self;
98
99 let mw = builder(MiddlewareBuilder(PhantomData));
100 RouterBuilder {
101 config,
102 middleware: MiddlewareLayerBuilder {
103 middleware,
104 mw,
105 phantom: PhantomData,
106 },
107 queries,
108 mutations,
109 subscriptions,
110 type_map,
111 phantom: PhantomData,
112 }
113 }
114
115 pub fn query<TResolver, TArg, TResult, TResultMarker>(
116 mut self,
117 key: &'static str,
118 builder: impl Fn(
119 UnbuiltProcedureBuilder<TLayerCtx, TResolver>,
120 ) -> BuiltProcedureBuilder<TResolver>,
121 ) -> Self
122 where
123 TArg: DeserializeOwned + Type,
124 TResult: RequestLayer<TResultMarker>,
125 TResolver: Fn(TLayerCtx, TArg) -> TResult + Send + Sync + 'static,
126 {
127 let resolver = builder(UnbuiltProcedureBuilder::default()).resolver;
128 self.queries.append(
129 key.into(),
130 self.middleware.build(ResolverLayer {
131 func: move |ctx, input, _| {
132 resolver.exec(
133 ctx,
134 serde_json::from_value(input).map_err(ExecError::DeserializingArgErr)?,
135 )
136 },
137 phantom: PhantomData,
138 }),
139 TResolver::typedef(&mut self.type_map),
140 );
141 self
142 }
143
144 pub fn mutation<TResolver, TArg, TResult, TResultMarker>(
145 mut self,
146 key: &'static str,
147 builder: impl Fn(
148 UnbuiltProcedureBuilder<TLayerCtx, TResolver>,
149 ) -> BuiltProcedureBuilder<TResolver>,
150 ) -> Self
151 where
152 TArg: DeserializeOwned + Type,
153 TResult: RequestLayer<TResultMarker>,
154 TResolver: Fn(TLayerCtx, TArg) -> TResult + Send + Sync + 'static,
155 {
156 let resolver = builder(UnbuiltProcedureBuilder::default()).resolver;
157 self.mutations.append(
158 key.into(),
159 self.middleware.build(ResolverLayer {
160 func: move |ctx, input, _| {
161 resolver.exec(
162 ctx,
163 serde_json::from_value(input).map_err(ExecError::DeserializingArgErr)?,
164 )
165 },
166 phantom: PhantomData,
167 }),
168 TResolver::typedef(&mut self.type_map),
169 );
170 self
171 }
172
173 pub fn subscription<TResolver, TArg, TStream, TResult, TResultMarker>(
174 mut self,
175 key: &'static str,
176 builder: impl Fn(
177 UnbuiltProcedureBuilder<TLayerCtx, TResolver>,
178 ) -> BuiltProcedureBuilder<TResolver>,
179 ) -> Self
180 where
181 TArg: DeserializeOwned + Type,
182 TStream: Stream<Item = TResult> + Send + 'static,
183 TResult: Serialize + Type,
184 TResolver: Fn(TLayerCtx, TArg) -> TStream
185 + StreamResolver<TLayerCtx, DoubleArgStreamMarker<TArg, TResultMarker, TStream>>
186 + Send
187 + Sync
188 + 'static,
189 {
190 let resolver = builder(UnbuiltProcedureBuilder::default()).resolver;
191 self.subscriptions.append(
192 key.into(),
193 self.middleware.build(ResolverLayer {
194 func: move |ctx, input, _| {
195 resolver.exec(
196 ctx,
197 serde_json::from_value(input).map_err(ExecError::DeserializingArgErr)?,
198 )
199 },
200 phantom: PhantomData,
201 }),
202 TResolver::typedef(&mut self.type_map),
203 );
204 self
205 }
206
207 pub fn merge<TNewLayerCtx, TIncomingMiddleware>(
208 mut self,
209 prefix: &'static str,
210 router: RouterBuilder<TLayerCtx, TMeta, TIncomingMiddleware>,
211 ) -> Self
212 where
213 TNewLayerCtx: 'static,
214 TIncomingMiddleware:
215 MiddlewareBuilderLike<TLayerCtx, LayerContext = TNewLayerCtx> + Send + 'static,
216 {
217 #[allow(clippy::panic)]
218 if prefix.is_empty() || prefix.starts_with("rpc.") || prefix.starts_with("rspc.") {
219 panic!(
220 "rspc error: attempted to merge a router with the prefix '{}', however this name is not allowed.",
221 prefix
222 );
223 }
224
225 for (key, query) in router.queries.store {
228 self.queries.append(
230 format!("{}{}", prefix, key),
231 self.middleware.build(query.exec),
232 query.ty,
233 );
234 }
235
236 for (key, mutation) in router.mutations.store {
237 self.mutations.append(
239 format!("{}{}", prefix, key),
240 self.middleware.build(mutation.exec),
241 mutation.ty,
242 );
243 }
244
245 for (key, subscription) in router.subscriptions.store {
246 self.subscriptions.append(
248 format!("{}{}", prefix, key),
249 self.middleware.build(subscription.exec),
250 subscription.ty,
251 );
252 }
253
254 self.type_map.extend(&router.type_map);
255
256 self
257 }
258
259 pub fn legacy_merge<TNewLayerCtx, TIncomingMiddleware>(
265 self,
266 prefix: &'static str,
267 router: RouterBuilder<TLayerCtx, TMeta, TIncomingMiddleware>,
268 ) -> RouterBuilder<
269 TCtx,
270 TMeta,
271 MiddlewareMerger<TCtx, TLayerCtx, TNewLayerCtx, TMiddleware, TIncomingMiddleware>,
272 >
273 where
274 TNewLayerCtx: 'static,
275 TIncomingMiddleware:
276 MiddlewareBuilderLike<TLayerCtx, LayerContext = TNewLayerCtx> + Send + 'static,
277 {
278 #[allow(clippy::panic)]
279 if prefix.is_empty() || prefix.starts_with("rpc.") || prefix.starts_with("rspc.") {
280 panic!(
281 "rspc error: attempted to merge a router with the prefix '{}', however this name is not allowed.",
282 prefix
283 );
284 }
285
286 let Self {
287 config,
288 middleware,
289 mut queries,
290 mut mutations,
291 mut subscriptions,
292 mut type_map,
293 ..
294 } = self;
295
296 for (key, query) in router.queries.store {
297 queries.append(
298 format!("{}{}", prefix, key),
299 middleware.build(query.exec),
300 query.ty,
301 );
302 }
303
304 for (key, mutation) in router.mutations.store {
305 mutations.append(
306 format!("{}{}", prefix, key),
307 middleware.build(mutation.exec),
308 mutation.ty,
309 );
310 }
311
312 for (key, subscription) in router.subscriptions.store {
313 subscriptions.append(
314 format!("{}{}", prefix, key),
315 middleware.build(subscription.exec),
316 subscription.ty,
317 );
318 }
319
320 type_map.extend(&router.type_map);
321
322 RouterBuilder {
323 config,
324 middleware: MiddlewareMerger {
325 middleware,
326 middleware2: router.middleware,
327 phantom: PhantomData,
328 },
329 queries,
330 mutations,
331 subscriptions,
332 type_map,
333 phantom: PhantomData,
334 }
335 }
336
337 pub fn build(self) -> Router<TCtx, TMeta> {
338 let Self {
339 config,
340 queries,
341 mutations,
342 subscriptions,
343 type_map,
344 ..
345 } = self;
346
347 let export_path = config.export_bindings_on_build.clone();
348 let router = Router {
349 config,
350 queries,
351 mutations,
352 subscriptions,
353 type_map,
354 phantom: PhantomData,
355 };
356
357 #[cfg(debug_assertions)]
358 #[allow(clippy::unwrap_used)]
359 if let Some(export_path) = export_path {
360 router.export_ts(export_path).unwrap();
361 }
362
363 router
364 }
365}