1use std::collections::HashMap;
2use std::{collections::HashSet, convert::Infallible, fmt::Write as _, fs, path::Path};
3
4use axum::body::Body;
5use futures::FutureExt;
6use http::{HeaderValue, Method, Request};
7use jsonrpsee::server::ws::is_upgrade_request;
8pub use jsonrpsee::server::ServerHandle;
9use jsonrpsee::RpcModule;
10use tower::service_fn;
11use tower::Service;
12use tower::ServiceBuilder;
13
14use crate::builder::*;
15use crate::RequestKind;
16
17#[derive(Clone)]
21pub struct Router<Ctx> {
22 nested_routers: Vec<(&'static str, Router<Ctx>)>,
23 handlers: Vec<HandlerCallbacks<Ctx>>,
24}
25
26impl<Ctx> Router<Ctx>
27where
28 Ctx: Clone + Send + Sync + 'static,
29{
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn handler<H: Handler<Ctx>>(mut self, handler: H) -> Self {
37 self.handlers.push(HandlerCallbacks::from_handler(handler));
38
39 self
40 }
41
42 pub fn nest(mut self, namespace: &'static str, router: Router<Ctx>) -> Self {
44 self.nested_routers.push((namespace, router));
45
46 self
47 }
48
49 pub fn write_bindings_to_dir(&self, out_dir: impl AsRef<Path>) {
52 let out_dir = out_dir.as_ref();
53
54 fs::create_dir_all(out_dir).unwrap();
56
57 fs::remove_dir_all(out_dir).unwrap();
59
60 fs::create_dir_all(out_dir).unwrap();
62
63 let header = String::from(include_str!("../header.txt"));
64
65 let (imports, exports, _types) = self
67 .get_handlers()
68 .into_iter()
69 .flat_map(|handler| {
70 (handler.export_all_dependencies_to)(out_dir)
71 .unwrap()
72 .into_iter()
73 .map(|dep| {
74 (
75 format!("./{}", dep.output_path.to_str().unwrap()),
76 dep.ts_name,
77 )
78 })
79 .chain((handler.qubit_types)().into_iter().map(|ty| ty.to_ts()))
80 })
81 .fold(
82 (String::new(), String::new(), HashSet::new()),
83 |(mut imports, mut exports, mut types), ty| {
84 if types.contains(&ty) {
85 return (imports, exports, types);
86 }
87
88 let (package, ty_name) = ty;
89
90 writeln!(
91 &mut imports,
92 r#"import type {{ {ty_name} }} from "{package}";"#,
93 )
94 .unwrap();
95
96 writeln!(
97 &mut exports,
98 r#"export type {{ {ty_name} }} from "{package}";"#,
99 )
100 .unwrap();
101
102 types.insert((package, ty_name));
103
104 (imports, exports, types)
105 },
106 );
107
108 let server_type = format!("export type QubitServer = {};", self.get_type());
110
111 fs::write(
113 out_dir.join("index.ts"),
114 [header, imports, exports, server_type]
115 .into_iter()
116 .filter(|part| !part.is_empty())
117 .collect::<Vec<_>>()
118 .join("\n"),
119 )
120 .unwrap();
121 }
122
123 pub fn to_service(
126 self,
127 ctx: Ctx,
128 ) -> (
129 impl Service<
130 hyper::Request<axum::body::Body>,
131 Response = jsonrpsee::server::HttpResponse,
132 Error = Infallible,
133 Future = impl Send,
134 > + Clone,
135 ServerHandle,
136 ) {
137 let (stop_handle, server_handle) = jsonrpsee::server::stop_channel();
139
140 let mut service = jsonrpsee::server::Server::builder()
142 .set_http_middleware(ServiceBuilder::new().map_request(|mut req: Request<_>| {
143 let request_type = if matches!(req.method(), &Method::GET)
145 && !is_upgrade_request(&req)
146 {
147 *req.method_mut() = Method::POST;
150
151 let headers = req.headers_mut();
153 headers.insert(
154 hyper::header::CONTENT_TYPE,
155 HeaderValue::from_static("application/json"),
156 );
157 headers.insert(
158 hyper::header::ACCEPT,
159 HeaderValue::from_static("application/json"),
160 );
161
162 if let Some(body) = req
164 .uri()
166 .query()
167 .and_then(|query| serde_qs::from_str::<HashMap<String, String>>(query).ok())
169 .and_then(|mut query| query.remove("input"))
171 .map(|input| urlencoding::decode(&input).unwrap_or_default().to_string())
173 {
174 *req.body_mut() = Body::from(body);
176 }
177
178 RequestKind::Query
179 } else {
180 RequestKind::Any
181 };
182
183 req.extensions_mut().insert(request_type);
185
186 req
187 }))
188 .to_service_builder()
189 .build(self.build_rpc_module(ctx, None), stop_handle);
190
191 (
192 service_fn(move |req: hyper::Request<axum::body::Body>| {
193 let call = service.call(req);
194
195 async move {
196 match call.await {
197 Ok(response) => Ok::<_, Infallible>(response),
198 Err(_) => unreachable!(),
199 }
200 }
201 .boxed()
202 }),
203 server_handle,
204 )
205 }
206
207 fn get_type(&self) -> String {
209 let handlers = self
211 .handlers
212 .iter()
213 .map(|handler| {
215 let handler_type = (handler.get_type)();
216 format!("{}: {}", handler_type.name, handler_type.signature)
217 })
218 .chain(
219 self.nested_routers.iter().map(|(namespace, router)| {
221 let router_type = router.get_type();
222 format!("{namespace}: {router_type}")
223 }),
224 )
225 .collect::<Vec<_>>();
226
227 format!("{{ {} }}", handlers.join(", "))
229 }
230
231 fn build_rpc_module(self, ctx: Ctx, namespace: Option<&'static str>) -> RpcModule<Ctx> {
237 let rpc_module = self
238 .handlers
239 .into_iter()
240 .fold(
241 RpcBuilder::with_namespace(ctx.clone(), namespace),
242 |rpc_builder, handler| (handler.register)(rpc_builder),
243 )
244 .build();
245
246 let parent_namespace = namespace;
248 self.nested_routers
249 .into_iter()
250 .fold(rpc_module, |mut rpc_module, (namespace, router)| {
251 let namespace = if let Some(parent_namespace) = parent_namespace {
252 format!("{parent_namespace}.{namespace}").leak()
254 } else {
255 namespace
256 };
257
258 rpc_module
259 .merge(router.build_rpc_module(ctx.clone(), Some(namespace)))
260 .unwrap();
261
262 rpc_module
263 })
264 }
265
266 fn get_handlers(&self) -> Vec<HandlerCallbacks<Ctx>> {
267 self.handlers
268 .iter()
269 .cloned()
270 .chain(
271 self.nested_routers
272 .iter()
273 .flat_map(|(_, router)| router.get_handlers()),
274 )
275 .collect()
276 }
277}
278
279impl<Ctx> Default for Router<Ctx> {
280 fn default() -> Self {
281 Self {
282 nested_routers: Default::default(),
283 handlers: Default::default(),
284 }
285 }
286}