qubit/server/
router.rs

1use std::collections::HashMap;
2use std::{collections::HashSet, convert::Infallible, fmt::Write as _, fs, path::Path};
3
4use axum::body::Body;
5use futures::FutureExt;
6use http::{HeaderValue, Method, Request};
7use jsonrpsee::server::ws::is_upgrade_request;
8pub use jsonrpsee::server::ServerHandle;
9use jsonrpsee::RpcModule;
10use tower::service_fn;
11use tower::Service;
12use tower::ServiceBuilder;
13
14use crate::builder::*;
15use crate::RequestKind;
16
17/// Router for the RPC server. Can have different handlers attached to it, as well as nested
18/// routers in order to create a hierarchy. It is also capable of generating its own type, suitable
19/// for consumption by a TypeScript client.
20#[derive(Clone)]
21pub struct Router<Ctx> {
22    nested_routers: Vec<(&'static str, Router<Ctx>)>,
23    handlers: Vec<HandlerCallbacks<Ctx>>,
24}
25
26impl<Ctx> Router<Ctx>
27where
28    Ctx: Clone + Send + Sync + 'static,
29{
30    /// Create a new instance of the router.
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Attach a handler to the router.
36    pub fn handler<H: Handler<Ctx>>(mut self, handler: H) -> Self {
37        self.handlers.push(HandlerCallbacks::from_handler(handler));
38
39        self
40    }
41
42    /// Nest another router within this router, under the provided namespace.
43    pub fn nest(mut self, namespace: &'static str, router: Router<Ctx>) -> Self {
44        self.nested_routers.push((namespace, router));
45
46        self
47    }
48
49    /// Write required bindings for this router the the provided directory. The directory will be
50    /// cleared, so anything within will be lost.
51    pub fn write_bindings_to_dir(&self, out_dir: impl AsRef<Path>) {
52        let out_dir = out_dir.as_ref();
53
54        // Make sure the directory path exists
55        fs::create_dir_all(out_dir).unwrap();
56
57        // Clear the directiry
58        fs::remove_dir_all(out_dir).unwrap();
59
60        // Re-create the directory
61        fs::create_dir_all(out_dir).unwrap();
62
63        let header = String::from(include_str!("../header.txt"));
64
65        // Export all the dependencies, and create their import statements
66        let (imports, exports, _types) = self
67            .get_handlers()
68            .into_iter()
69            .flat_map(|handler| {
70                (handler.export_all_dependencies_to)(out_dir)
71                    .unwrap()
72                    .into_iter()
73                    .map(|dep| {
74                        (
75                            format!("./{}", dep.output_path.to_str().unwrap()),
76                            dep.ts_name,
77                        )
78                    })
79                    .chain((handler.qubit_types)().into_iter().map(|ty| ty.to_ts()))
80            })
81            .fold(
82                (String::new(), String::new(), HashSet::new()),
83                |(mut imports, mut exports, mut types), ty| {
84                    if types.contains(&ty) {
85                        return (imports, exports, types);
86                    }
87
88                    let (package, ty_name) = ty;
89
90                    writeln!(
91                        &mut imports,
92                        r#"import type {{ {ty_name} }} from "{package}";"#,
93                    )
94                    .unwrap();
95
96                    writeln!(
97                        &mut exports,
98                        r#"export type {{ {ty_name} }} from "{package}";"#,
99                    )
100                    .unwrap();
101
102                    types.insert((package, ty_name));
103
104                    (imports, exports, types)
105                },
106            );
107
108        // Generate server type
109        let server_type = format!("export type QubitServer = {};", self.get_type());
110
111        // Write out index file
112        fs::write(
113            out_dir.join("index.ts"),
114            [header, imports, exports, server_type]
115                .into_iter()
116                .filter(|part| !part.is_empty())
117                .collect::<Vec<_>>()
118                .join("\n"),
119        )
120        .unwrap();
121    }
122
123    /// Turn the router into a [`tower::Service`], so that it can be nested into a HTTP server.
124    /// The provided `ctx` will be cloned for each request.
125    pub fn to_service(
126        self,
127        ctx: Ctx,
128    ) -> (
129        impl Service<
130                hyper::Request<axum::body::Body>,
131                Response = jsonrpsee::server::HttpResponse,
132                Error = Infallible,
133                Future = impl Send,
134            > + Clone,
135        ServerHandle,
136    ) {
137        // Generate the stop and server handles for the service
138        let (stop_handle, server_handle) = jsonrpsee::server::stop_channel();
139
140        // Build out the RPC module into a service
141        let mut service = jsonrpsee::server::Server::builder()
142            .set_http_middleware(ServiceBuilder::new().map_request(|mut req: Request<_>| {
143                // Check if this is a GET request, and if it is convert it to a regular POST
144                let request_type = if matches!(req.method(), &Method::GET)
145                    && !is_upgrade_request(&req)
146                {
147                    // Change this request into a regular POST request, and indicate that it should
148                    // be a query.
149                    *req.method_mut() = Method::POST;
150
151                    // Update the headers
152                    let headers = req.headers_mut();
153                    headers.insert(
154                        hyper::header::CONTENT_TYPE,
155                        HeaderValue::from_static("application/json"),
156                    );
157                    headers.insert(
158                        hyper::header::ACCEPT,
159                        HeaderValue::from_static("application/json"),
160                    );
161
162                    // Convert the `input` field of the query string into the request body
163                    if let Some(body) = req
164                        // Extract the query string
165                        .uri()
166                        .query()
167                        // Parse the query string
168                        .and_then(|query| serde_qs::from_str::<HashMap<String, String>>(query).ok())
169                        // Take out the input
170                        .and_then(|mut query| query.remove("input"))
171                        // URL decode the input
172                        .map(|input| urlencoding::decode(&input).unwrap_or_default().to_string())
173                    {
174                        // Set the request body
175                        *req.body_mut() = Body::from(body);
176                    }
177
178                    RequestKind::Query
179                } else {
180                    RequestKind::Any
181                };
182
183                // Set the request kind
184                req.extensions_mut().insert(request_type);
185
186                req
187            }))
188            .to_service_builder()
189            .build(self.build_rpc_module(ctx, None), stop_handle);
190
191        (
192            service_fn(move |req: hyper::Request<axum::body::Body>| {
193                let call = service.call(req);
194
195                async move {
196                    match call.await {
197                        Ok(response) => Ok::<_, Infallible>(response),
198                        Err(_) => unreachable!(),
199                    }
200                }
201                .boxed()
202            }),
203            server_handle,
204        )
205    }
206
207    /// Get the TypeScript type of this router.
208    fn get_type(&self) -> String {
209        // Generate types of all handlers, including nested handlers
210        let handlers = self
211            .handlers
212            .iter()
213            // Generate types of handlers
214            .map(|handler| {
215                let handler_type = (handler.get_type)();
216                format!("{}: {}", handler_type.name, handler_type.signature)
217            })
218            .chain(
219                // Generate types of nested routers
220                self.nested_routers.iter().map(|(namespace, router)| {
221                    let router_type = router.get_type();
222                    format!("{namespace}: {router_type}")
223                }),
224            )
225            .collect::<Vec<_>>();
226
227        // Generate the router type
228        format!("{{ {} }}", handlers.join(", "))
229    }
230
231    /// Generate a [`jsonrpsee::RpcModule`] from this router, with an optional namespace.
232    ///
233    /// Uses an [`RpcBuilder`] to incrementally add query and subcription handlers, passing the
234    /// instance through to the [`HandlerCallbacks`] attached to this router, so they can register
235    /// against the [`RpcModule`] (including namespacing).
236    fn build_rpc_module(self, ctx: Ctx, namespace: Option<&'static str>) -> RpcModule<Ctx> {
237        let rpc_module = self
238            .handlers
239            .into_iter()
240            .fold(
241                RpcBuilder::with_namespace(ctx.clone(), namespace),
242                |rpc_builder, handler| (handler.register)(rpc_builder),
243            )
244            .build();
245
246        // Generate modules for nested routers, and merge them with the existing router
247        let parent_namespace = namespace;
248        self.nested_routers
249            .into_iter()
250            .fold(rpc_module, |mut rpc_module, (namespace, router)| {
251                let namespace = if let Some(parent_namespace) = parent_namespace {
252                    // WARN: Probably not great leaking here
253                    format!("{parent_namespace}.{namespace}").leak()
254                } else {
255                    namespace
256                };
257
258                rpc_module
259                    .merge(router.build_rpc_module(ctx.clone(), Some(namespace)))
260                    .unwrap();
261
262                rpc_module
263            })
264    }
265
266    fn get_handlers(&self) -> Vec<HandlerCallbacks<Ctx>> {
267        self.handlers
268            .iter()
269            .cloned()
270            .chain(
271                self.nested_routers
272                    .iter()
273                    .flat_map(|(_, router)| router.get_handlers()),
274            )
275            .collect()
276    }
277}
278
279impl<Ctx> Default for Router<Ctx> {
280    fn default() -> Self {
281        Self {
282            nested_routers: Default::default(),
283            handlers: Default::default(),
284        }
285    }
286}