okapi_operation/axum_integration/
router.rs1use std::{collections::HashMap, convert::Infallible, fmt};
2
3use axum::{
4 Router as AxumRouter, extract::Request, handler::Handler, http::Method, response::IntoResponse,
5 routing::Route,
6};
7use tower::{Layer, Service};
8
9use super::{
10 get,
11 method_router::{MethodRouter, MethodRouterOperations},
12 operations::RoutesOperations,
13 utils::convert_axum_path_to_openapi,
14};
15use crate::OpenApiBuilder;
16
/// Path at which [`Router::finish_openapi`] mounts the generated OpenAPI
/// specification when the caller passes `None` as the serve path.
pub const DEFAULT_OPENAPI_PATH: &str = "/openapi";
18
19pub struct Router<S = ()> {
25 axum_router: AxumRouter<S>,
26 routes_operations_map: HashMap<String, MethodRouterOperations>,
27 openapi_builder_template: OpenApiBuilder,
28}
29
30impl<S> From<AxumRouter<S>> for Router<S> {
31 fn from(value: AxumRouter<S>) -> Self {
32 Self {
33 axum_router: value,
34 routes_operations_map: Default::default(),
35 openapi_builder_template: OpenApiBuilder::default(),
36 }
37 }
38}
39
40impl<S> Clone for Router<S>
41where
42 S: Clone + Send + Sync + 'static,
43{
44 fn clone(&self) -> Self {
45 Self {
46 axum_router: self.axum_router.clone(),
47 routes_operations_map: self.routes_operations_map.clone(),
48 openapi_builder_template: self.openapi_builder_template.clone(),
49 }
50 }
51}
52
53impl<S> Default for Router<S>
54where
55 S: Clone + Send + Sync + 'static,
56{
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl<S> fmt::Debug for Router<S>
63where
64 S: fmt::Debug,
65{
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 self.axum_router.fmt(f)
68 }
69}
70
71impl<S> Router<S>
72where
73 S: Clone + Send + Sync + 'static,
74{
75 pub fn new() -> Self {
77 Self {
78 axum_router: AxumRouter::new(),
79 routes_operations_map: HashMap::new(),
80 openapi_builder_template: OpenApiBuilder::default(),
81 }
82 }
83
84 pub fn route<R>(mut self, path: &str, method_router: R) -> Self
105 where
106 R: Into<MethodRouter<S>>,
107 {
108 let method_router = method_router.into();
109
110 let s = self.routes_operations_map.entry(path.into()).or_default();
112 *s = s.clone().merge(method_router.operations);
113
114 Self {
115 axum_router: self
116 .axum_router
117 .route(path, method_router.axum_method_router),
118 ..self
119 }
120 }
121
122 pub fn route_service<Svc>(self, path: &str, service: Svc) -> Self
130 where
131 Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
132 Svc::Response: IntoResponse,
133 Svc::Future: Send + 'static,
134 {
135 Self {
136 axum_router: self.axum_router.route_service(path, service),
137 ..self
138 }
139 }
140
141 pub fn nest<R>(mut self, path: &str, router: R) -> Self
161 where
162 R: Into<Router<S>>,
163 {
164 let router = router.into();
165 for (inner_path, operation) in router.routes_operations_map.into_iter() {
166 let _ = self
167 .routes_operations_map
168 .insert(format!("{}{}", path, inner_path), operation);
169 }
170 Self {
171 axum_router: self.axum_router.nest(path, router.axum_router),
172 ..self
173 }
174 }
175
176 pub fn nest_service<Svc>(self, path: &str, svc: Svc) -> Self
180 where
181 Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
182 Svc::Response: IntoResponse,
183 Svc::Future: Send + 'static,
184 {
185 Self {
186 axum_router: self.axum_router.nest_service(path, svc),
187 ..self
188 }
189 }
190
191 pub fn merge<R>(mut self, other: R) -> Self
212 where
213 R: Into<Router<S>>,
214 {
215 let other = other.into();
216 self.routes_operations_map
217 .extend(other.routes_operations_map);
218 Self {
219 axum_router: self.axum_router.merge(other.axum_router),
220 ..self
221 }
222 }
223
224 pub fn layer<L>(self, layer: L) -> Router<S>
228 where
229 L: Layer<Route> + Clone + Send + Sync + 'static,
230 L::Service: Service<Request> + Clone + Send + Sync + 'static,
231 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
232 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
233 <L::Service as Service<Request>>::Future: Send + 'static,
234 {
235 Router {
236 axum_router: self.axum_router.layer(layer),
237 routes_operations_map: self.routes_operations_map,
238 openapi_builder_template: self.openapi_builder_template,
239 }
240 }
241
242 pub fn route_layer<L>(self, layer: L) -> Self
246 where
247 L: Layer<Route> + Clone + Send + Sync + 'static,
248 L::Service: Service<Request> + Clone + Send + Sync + 'static,
249 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
250 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
251 <L::Service as Service<Request>>::Future: Send + 'static,
252 {
253 Router {
254 axum_router: self.axum_router.route_layer(layer),
255 routes_operations_map: self.routes_operations_map,
256 openapi_builder_template: self.openapi_builder_template,
257 }
258 }
259
260 pub fn fallback<H, T>(self, handler: H) -> Self
269 where
270 H: Handler<T, S>,
271 T: 'static,
272 {
273 Router {
274 axum_router: self.axum_router.fallback(handler),
275 ..self
276 }
277 }
278
279 pub fn fallback_service<Svc>(self, svc: Svc) -> Self
287 where
288 Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
289 Svc::Response: IntoResponse,
290 Svc::Future: Send + 'static,
291 {
292 Router {
293 axum_router: self.axum_router.fallback_service(svc),
294 ..self
295 }
296 }
297
298 pub fn with_state<S2>(self, state: S) -> Router<S2> {
302 Router {
303 axum_router: self.axum_router.with_state(state),
304 routes_operations_map: self.routes_operations_map,
305 openapi_builder_template: self.openapi_builder_template,
306 }
307 }
308
309 pub fn into_parts(self) -> (AxumRouter<S>, RoutesOperations) {
311 (
312 self.axum_router,
313 RoutesOperations::new(self.routes_operations_map),
314 )
315 }
316
317 pub fn axum_router(&self) -> AxumRouter<S> {
319 self.axum_router.clone()
320 }
321
322 pub fn routes_operations(&self) -> RoutesOperations {
324 RoutesOperations::new(self.routes_operations_map.clone())
325 }
326
327 pub fn generate_openapi_builder(&self) -> OpenApiBuilder {
334 let routes = self.routes_operations().openapi_operation_generators();
335 let mut builder = self.openapi_builder_template.clone();
336 builder.operations(
339 routes
340 .into_iter()
341 .map(|((x, y), z)| (convert_axum_path_to_openapi(&x), y, z)),
342 );
343 builder
344 }
345
346 pub fn set_openapi_builder_template(&mut self, builder: OpenApiBuilder) -> &mut Self {
350 self.openapi_builder_template = builder;
351 self
352 }
353
354 pub fn update_openapi_builder_template<F>(&mut self, f: F) -> &mut Self
358 where
359 F: FnOnce(&mut OpenApiBuilder),
360 {
361 f(&mut self.openapi_builder_template);
362 self
363 }
364
365 pub fn openapi_builder_template_mut(&mut self) -> &mut OpenApiBuilder {
369 &mut self.openapi_builder_template
370 }
371
372 pub fn finish_openapi<'a>(
396 mut self,
397 serve_path: impl Into<Option<&'a str>>,
398 title: impl Into<String>,
399 version: impl Into<String>,
400 ) -> Result<AxumRouter<S>, anyhow::Error> {
401 let serve_path = serve_path.into().unwrap_or(DEFAULT_OPENAPI_PATH);
402
403 let spec = self
406 .generate_openapi_builder()
407 .operation(
408 convert_axum_path_to_openapi(serve_path),
409 Method::GET,
410 super::serve_openapi_spec__openapi,
411 )
412 .title(title)
413 .version(version)
414 .build()?;
415
416 self = self.route(serve_path, get(super::serve_openapi_spec).with_state(spec));
417
418 Ok(self.axum_router)
419 }
420}
421
#[cfg(test)]
mod tests {
    // The `serve` futures below are built but never polled on purpose (see the
    // comment in each test), which this lint would otherwise flag.
    #![allow(clippy::let_underscore_future)]

    use axum::{http::Method, routing::get as axum_get};
    use okapi::openapi3::Operation;
    use tokio::net::TcpListener;

    use super::*;
    use crate::{
        Components,
        axum_integration::{HandlerExt, get, post},
    };

    /// Placeholder generator: these tests only check which operations get
    /// recorded, so it is never invoked.
    fn openapi_generator(_: &mut Components) -> Result<Operation, anyhow::Error> {
        unimplemented!()
    }

    /// Plain axum routers and method routers can be mounted, and they record
    /// no operations.
    #[test]
    fn mount_axum_types() {
        let axum_router = AxumRouter::new().route("/get", axum_get(|| async {}));
        let (app, meta) = Router::new()
            .route("/", axum_get(|| async {}))
            .nest("/nested", axum_router.clone())
            .merge(axum_router)
            .into_parts();
        assert!(meta.0.is_empty());
        let make_service = app.into_make_service();
        // Never awaited: this only proves that `make_service` type-checks with
        // `axum::serve`.
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }

    /// Operations attached with `with_openapi` are recorded through `route`,
    /// `nest` and `merge`; undocumented handlers are not.
    #[test]
    fn mount() {
        let router = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let router2 = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let (app, ops) = Router::new()
            .route("/", get(|| async {}))
            .nest("/nested", router)
            .merge(router2)
            .route(
                "/my_path",
                get((|| async {}).with_openapi(openapi_generator)),
            )
            .route(
                "/my_path",
                post((|| async {}).with_openapi(openapi_generator)),
            )
            .into_parts();

        assert!(ops.get_path("/").is_none());
        assert!(ops.get_path("/get").is_none());
        assert!(ops.get_path("/nested/get").is_none());

        assert!(ops.get_path("/get_with_spec").is_some());
        assert!(ops.get("/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/get_with_spec", &Method::POST).is_none());
        assert!(ops.get_path("/nested/get_with_spec").is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::POST).is_none());

        assert!(ops.get("/my_path", &Method::GET).is_some());
        assert!(ops.get("/my_path", &Method::POST).is_some());

        let make_service = app.into_make_service();
        // Never awaited: this only proves that `make_service` type-checks with
        // `axum::serve`.
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }
}