okapi_operation/axum_integration/
router.rs1use std::{collections::HashMap, convert::Infallible, fmt};
2
3use axum::{
4 Router as AxumRouter, extract::Request, handler::Handler, http::Method, response::IntoResponse,
5 routing::Route,
6};
7use tower::{Layer, Service};
8
9use super::{
10 get,
11 method_router::{MethodRouter, MethodRouterOperations},
12 operations::RoutesOperations,
13};
14use crate::OpenApiBuilder;
15
/// Path at which [`Router::finish_openapi`] serves the specification when the
/// caller does not supply a serve path of its own.
pub const DEFAULT_OPENAPI_PATH: &str = "/openapi";
17
18pub struct Router<S = ()> {
24 axum_router: AxumRouter<S>,
25 routes_operations_map: HashMap<String, MethodRouterOperations>,
26 openapi_builder_template: OpenApiBuilder,
27}
28
29impl<S> From<AxumRouter<S>> for Router<S> {
30 fn from(value: AxumRouter<S>) -> Self {
31 Self {
32 axum_router: value,
33 routes_operations_map: Default::default(),
34 openapi_builder_template: OpenApiBuilder::default(),
35 }
36 }
37}
38
39impl<S> Clone for Router<S>
40where
41 S: Clone + Send + Sync + 'static,
42{
43 fn clone(&self) -> Self {
44 Self {
45 axum_router: self.axum_router.clone(),
46 routes_operations_map: self.routes_operations_map.clone(),
47 openapi_builder_template: self.openapi_builder_template.clone(),
48 }
49 }
50}
51
52impl<S> Default for Router<S>
53where
54 S: Clone + Send + Sync + 'static,
55{
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl<S> fmt::Debug for Router<S>
62where
63 S: fmt::Debug,
64{
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 self.axum_router.fmt(f)
67 }
68}
69
70impl<S> Router<S>
71where
72 S: Clone + Send + Sync + 'static,
73{
74 pub fn new() -> Self {
76 Self {
77 axum_router: AxumRouter::new(),
78 routes_operations_map: HashMap::new(),
79 openapi_builder_template: OpenApiBuilder::default(),
80 }
81 }
82
83 pub fn route<R>(mut self, path: &str, method_router: R) -> Self
104 where
105 R: Into<MethodRouter<S>>,
106 {
107 let method_router = method_router.into();
108
109 let s = self.routes_operations_map.entry(path.into()).or_default();
111 *s = s.clone().merge(method_router.operations);
112
113 Self {
114 axum_router: self
115 .axum_router
116 .route(path, method_router.axum_method_router),
117 ..self
118 }
119 }
120
121 pub fn route_service<Svc>(self, path: &str, service: Svc) -> Self
129 where
130 Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
131 Svc::Response: IntoResponse,
132 Svc::Future: Send + 'static,
133 {
134 Self {
135 axum_router: self.axum_router.route_service(path, service),
136 ..self
137 }
138 }
139
140 pub fn nest<R>(mut self, path: &str, router: R) -> Self
160 where
161 R: Into<Router<S>>,
162 {
163 let router = router.into();
164 for (inner_path, operation) in router.routes_operations_map.into_iter() {
165 let _ = self
166 .routes_operations_map
167 .insert(format!("{}{}", path, inner_path), operation);
168 }
169 Self {
170 axum_router: self.axum_router.nest(path, router.axum_router),
171 ..self
172 }
173 }
174
175 pub fn nest_service<Svc>(self, path: &str, svc: Svc) -> Self
179 where
180 Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
181 Svc::Response: IntoResponse,
182 Svc::Future: Send + 'static,
183 {
184 Self {
185 axum_router: self.axum_router.nest_service(path, svc),
186 ..self
187 }
188 }
189
190 pub fn merge<R>(mut self, other: R) -> Self
211 where
212 R: Into<Router<S>>,
213 {
214 let other = other.into();
215 self.routes_operations_map
216 .extend(other.routes_operations_map);
217 Self {
218 axum_router: self.axum_router.merge(other.axum_router),
219 ..self
220 }
221 }
222
223 pub fn layer<L>(self, layer: L) -> Router<S>
227 where
228 L: Layer<Route> + Clone + Send + Sync + 'static,
229 L::Service: Service<Request> + Clone + Send + Sync + 'static,
230 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
231 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
232 <L::Service as Service<Request>>::Future: Send + 'static,
233 {
234 Router {
235 axum_router: self.axum_router.layer(layer),
236 routes_operations_map: self.routes_operations_map,
237 openapi_builder_template: self.openapi_builder_template,
238 }
239 }
240
241 pub fn route_layer<L>(self, layer: L) -> Self
245 where
246 L: Layer<Route> + Clone + Send + Sync + 'static,
247 L::Service: Service<Request> + Clone + Send + Sync + 'static,
248 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
249 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
250 <L::Service as Service<Request>>::Future: Send + 'static,
251 {
252 Router {
253 axum_router: self.axum_router.route_layer(layer),
254 routes_operations_map: self.routes_operations_map,
255 openapi_builder_template: self.openapi_builder_template,
256 }
257 }
258
259 pub fn fallback<H, T>(self, handler: H) -> Self
268 where
269 H: Handler<T, S>,
270 T: 'static,
271 {
272 Router {
273 axum_router: self.axum_router.fallback(handler),
274 ..self
275 }
276 }
277
278 pub fn fallback_service<Svc>(self, svc: Svc) -> Self
286 where
287 Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
288 Svc::Response: IntoResponse,
289 Svc::Future: Send + 'static,
290 {
291 Router {
292 axum_router: self.axum_router.fallback_service(svc),
293 ..self
294 }
295 }
296
297 pub fn with_state<S2>(self, state: S) -> Router<S2> {
301 Router {
302 axum_router: self.axum_router.with_state(state),
303 routes_operations_map: self.routes_operations_map,
304 openapi_builder_template: self.openapi_builder_template,
305 }
306 }
307
308 pub fn into_parts(self) -> (AxumRouter<S>, RoutesOperations) {
310 (
311 self.axum_router,
312 RoutesOperations::new(self.routes_operations_map),
313 )
314 }
315
316 pub fn axum_router(&self) -> AxumRouter<S> {
318 self.axum_router.clone()
319 }
320
321 pub fn routes_operations(&self) -> RoutesOperations {
323 RoutesOperations::new(self.routes_operations_map.clone())
324 }
325
326 pub fn generate_openapi_builder(&self) -> OpenApiBuilder {
333 let routes = self.routes_operations().openapi_operation_generators();
334 let mut builder = self.openapi_builder_template.clone();
335 builder.operations(routes.into_iter().map(|((x, y), z)| (x, y, z)));
338 builder
339 }
340
341 pub fn set_openapi_builder_template(&mut self, builder: OpenApiBuilder) -> &mut Self {
345 self.openapi_builder_template = builder;
346 self
347 }
348
349 pub fn update_openapi_builder_template<F>(&mut self, f: F) -> &mut Self
353 where
354 F: FnOnce(&mut OpenApiBuilder),
355 {
356 f(&mut self.openapi_builder_template);
357 self
358 }
359
360 pub fn openapi_builder_template_mut(&mut self) -> &mut OpenApiBuilder {
364 &mut self.openapi_builder_template
365 }
366
367 pub fn finish_openapi<'a>(
391 mut self,
392 serve_path: impl Into<Option<&'a str>>,
393 title: impl Into<String>,
394 version: impl Into<String>,
395 ) -> Result<AxumRouter<S>, anyhow::Error> {
396 let serve_path = serve_path.into().unwrap_or(DEFAULT_OPENAPI_PATH);
397
398 let spec = self
401 .generate_openapi_builder()
402 .operation(serve_path, Method::GET, super::serve_openapi_spec__openapi)
403 .title(title)
404 .version(version)
405 .build()?;
406
407 self = self.route(serve_path, get(super::serve_openapi_spec).with_state(spec));
408
409 Ok(self.axum_router)
410 }
411}
412
#[cfg(test)]
mod tests {
    // The `axum::serve` futures built below are dropped without being polled;
    // they only check that the produced service type-checks.
    #![allow(clippy::let_underscore_future)]

    use axum::{http::Method, routing::get as axum_get};
    use okapi::openapi3::Operation;
    use tokio::net::TcpListener;

    use super::*;
    use crate::{
        BuilderOptions, Components,
        axum_integration::{HandlerExt, get, post},
    };

    /// Placeholder operation generator. It only has to have the right
    /// signature; calling it would panic via `unimplemented!()`.
    fn openapi_generator(
        _: &mut Components,
        _: &BuilderOptions,
    ) -> Result<Operation, anyhow::Error> {
        unimplemented!()
    }

    /// Plain axum routers and method routers can be routed, nested and merged,
    /// and none of them contributes any recorded operation.
    #[test]
    fn mount_axum_types() {
        let axum_router = AxumRouter::new().route("/get", axum_get(|| async {}));
        let (app, meta) = Router::new()
            .route("/", axum_get(|| async {}))
            .nest("/nested", axum_router.clone())
            .merge(axum_router)
            .into_parts();
        assert!(meta.0.is_empty());

        // Never polled: only type-checks that the router can be served.
        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }

    /// Operations are recorded only for handlers carrying an OpenAPI
    /// generator, under the full path and method they were mounted with.
    #[test]
    fn mount() {
        let router = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let router2 = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let (app, ops) = Router::new()
            .route("/", get(|| async {}))
            .nest("/nested", router)
            .merge(router2)
            .route(
                "/my_path",
                get((|| async {}).with_openapi(openapi_generator)),
            )
            .route(
                "/my_path",
                post((|| async {}).with_openapi(openapi_generator)),
            )
            .into_parts();

        // Handlers without a generator leave no trace.
        assert!(ops.get_path("/").is_none());
        assert!(ops.get_path("/get").is_none());
        assert!(ops.get_path("/nested/get").is_none());

        // Merged and nested routers keep their documented operations.
        assert!(ops.get_path("/get_with_spec").is_some());
        assert!(ops.get("/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/get_with_spec", &Method::POST).is_none());
        assert!(ops.get_path("/nested/get_with_spec").is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::POST).is_none());

        // Two `route` calls on the same path accumulate both methods.
        assert!(ops.get("/my_path", &Method::GET).is_some());
        assert!(ops.get("/my_path", &Method::POST).is_some());

        // Never polled: only type-checks that the router can be served.
        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }
}