// okapi_operation/axum_integration/router.rs

use std::{collections::HashMap, convert::Infallible, fmt, mem};

use axum::{
    Router as AxumRouter, extract::Request, handler::Handler, http::Method, response::IntoResponse,
    routing::Route,
};
use tower::{Layer, Service};

use super::{
    get,
    method_router::{MethodRouter, MethodRouterOperations},
    operations::RoutesOperations,
    utils::convert_axum_path_to_openapi,
};
use crate::OpenApiBuilder;

/// Path at which [`Router::finish_openapi`] serves the generated OpenAPI
/// specification when no explicit serve path is passed.
pub const DEFAULT_OPENAPI_PATH: &str = "/openapi";

19/// Drop-in replacement for [`axum::Router`], which supports OpenAPI operations.
20///
21/// This replacement cannot be used as [`Service`] instead require explicit
22/// convertion of this type to `axum::Router`. This is done to ensure that
23/// OpenAPI specification generated and mounted.
24pub struct Router<S = ()> {
25    axum_router: AxumRouter<S>,
26    routes_operations_map: HashMap<String, MethodRouterOperations>,
27    openapi_builder_template: OpenApiBuilder,
28}
29
30impl<S> From<AxumRouter<S>> for Router<S> {
31    fn from(value: AxumRouter<S>) -> Self {
32        Self {
33            axum_router: value,
34            routes_operations_map: Default::default(),
35            openapi_builder_template: OpenApiBuilder::default(),
36        }
37    }
38}
39
40impl<S> Clone for Router<S>
41where
42    S: Clone + Send + Sync + 'static,
43{
44    fn clone(&self) -> Self {
45        Self {
46            axum_router: self.axum_router.clone(),
47            routes_operations_map: self.routes_operations_map.clone(),
48            openapi_builder_template: self.openapi_builder_template.clone(),
49        }
50    }
51}
52
53impl<S> Default for Router<S>
54where
55    S: Clone + Send + Sync + 'static,
56{
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl<S> fmt::Debug for Router<S>
63where
64    S: fmt::Debug,
65{
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        self.axum_router.fmt(f)
68    }
69}
70
71impl<S> Router<S>
72where
73    S: Clone + Send + Sync + 'static,
74{
75    /// Create new router.
76    pub fn new() -> Self {
77        Self {
78            axum_router: AxumRouter::new(),
79            routes_operations_map: HashMap::new(),
80            openapi_builder_template: OpenApiBuilder::default(),
81        }
82    }
83
84    /// Add another route to the router.
85    ///
86    /// This method works for both [`MethodRouter`] and one from axum.
87    ///
88    /// For details see [`axum::Router::route`].
89    ///
90    /// # Example
91    ///
92    /// ```rust
93    /// # use okapi_operation::{*, axum_integration::*};
94    /// #[openapi]
95    /// async fn handler() {}
96    ///
97    /// let app = Router::new().route("/", get(openapi_handler!(handler)));
98    /// # async {
99    /// # let (app, _) = app.into_parts();
100    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
101    /// # axum::serve(listener, app.into_make_service()).await.unwrap()
102    /// # };
103    /// ```
104    pub fn route<R>(mut self, path: &str, method_router: R) -> Self
105    where
106        R: Into<MethodRouter<S>>,
107    {
108        let method_router = method_router.into();
109
110        // Merge operations
111        let s = self.routes_operations_map.entry(path.into()).or_default();
112        *s = s.clone().merge(method_router.operations);
113
114        Self {
115            axum_router: self
116                .axum_router
117                .route(path, method_router.axum_method_router),
118            ..self
119        }
120    }
121
122    /// Add another route to the router that calls a [`Service`].
123    ///
124    /// For details see [`axum::Router::route_service`].
125    ///
126    /// # Example
127    ///
128    /// TODO
129    pub fn route_service<Svc>(self, path: &str, service: Svc) -> Self
130    where
131        Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
132        Svc::Response: IntoResponse,
133        Svc::Future: Send + 'static,
134    {
135        Self {
136            axum_router: self.axum_router.route_service(path, service),
137            ..self
138        }
139    }
140
141    /// Nest a router at some path.
142    ///
143    /// This method works for both [`Router`] and one from axum.
144    ///
145    /// For details see [`axum::Router::nest`].
146    ///
147    /// # Example
148    ///
149    /// ```rust
150    /// # use okapi_operation::{*, axum_integration::*};
151    /// #[openapi]
152    /// async fn handler() {}
153    /// let handler_router = Router::new().route("/", get(openapi_handler!(handler)));
154    /// let app = Router::new().nest("/handle", handler_router);
155    /// # async {
156    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
157    /// # axum::serve(listener, app.into_parts().0.into_make_service()).await.unwrap()
158    /// # };
159    /// ```
160    pub fn nest<R>(mut self, path: &str, router: R) -> Self
161    where
162        R: Into<Router<S>>,
163    {
164        let router = router.into();
165        for (inner_path, operation) in router.routes_operations_map.into_iter() {
166            let _ = self
167                .routes_operations_map
168                .insert(format!("{}{}", path, inner_path), operation);
169        }
170        Self {
171            axum_router: self.axum_router.nest(path, router.axum_router),
172            ..self
173        }
174    }
175
176    /// Like `nest`, but accepts an arbitrary [`Service`].
177    ///
178    /// For details see [`axum::Router::nest_service`].
179    pub fn nest_service<Svc>(self, path: &str, svc: Svc) -> Self
180    where
181        Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
182        Svc::Response: IntoResponse,
183        Svc::Future: Send + 'static,
184    {
185        Self {
186            axum_router: self.axum_router.nest_service(path, svc),
187            ..self
188        }
189    }
190
191    /// Merge two routers into one.
192    ///
193    /// This method works for both [`Router`] and one from axum.
194    ///
195    /// For details see [`axum::Router::merge`].
196    ///
197    /// # Example
198    ///
199    /// ```rust
200    /// # use okapi_operation::{*, axum_integration::*};
201    /// #[openapi]
202    /// async fn handler() {}
203    /// let handler_router = Router::new().route("/another_handler", get(openapi_handler!(handler)));
204    /// let app = Router::new().route("/", get(openapi_handler!(handler))).merge(handler_router);
205    /// # async {
206    /// # let (app, _) = app.into_parts();
207    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
208    /// # axum::serve(listener, app.into_make_service()).await.unwrap()
209    /// # };
210    /// ```
211    pub fn merge<R>(mut self, other: R) -> Self
212    where
213        R: Into<Router<S>>,
214    {
215        let other = other.into();
216        self.routes_operations_map
217            .extend(other.routes_operations_map);
218        Self {
219            axum_router: self.axum_router.merge(other.axum_router),
220            ..self
221        }
222    }
223
224    /// Apply a [`tower::Layer`] to the router.
225    ///
226    /// For details see [`axum::Router::layer`].
227    pub fn layer<L>(self, layer: L) -> Router<S>
228    where
229        L: Layer<Route> + Clone + Send + Sync + 'static,
230        L::Service: Service<Request> + Clone + Send + Sync + 'static,
231        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
232        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
233        <L::Service as Service<Request>>::Future: Send + 'static,
234    {
235        Router {
236            axum_router: self.axum_router.layer(layer),
237            routes_operations_map: self.routes_operations_map,
238            openapi_builder_template: self.openapi_builder_template,
239        }
240    }
241
242    /// Apply a [`tower::Layer`] to the router that will only run if the request matches a route.
243    ///
244    /// For details see [`axum::Router::route_layer`].
245    pub fn route_layer<L>(self, layer: L) -> Self
246    where
247        L: Layer<Route> + Clone + Send + Sync + 'static,
248        L::Service: Service<Request> + Clone + Send + Sync + 'static,
249        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
250        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
251        <L::Service as Service<Request>>::Future: Send + 'static,
252    {
253        Router {
254            axum_router: self.axum_router.route_layer(layer),
255            routes_operations_map: self.routes_operations_map,
256            openapi_builder_template: self.openapi_builder_template,
257        }
258    }
259
260    // TODO: somehow mount openapi doc from this handler
261    /// Add a fallback [`Service`] to the router.
262    ///
263    /// For details see [`axum::Router::fallback_service`].
264    ///
265    /// # Note
266    ///
267    /// This method doesn't add anything to OpenaAPI spec.
268    pub fn fallback<H, T>(self, handler: H) -> Self
269    where
270        H: Handler<T, S>,
271        T: 'static,
272    {
273        Router {
274            axum_router: self.axum_router.fallback(handler),
275            ..self
276        }
277    }
278
279    /// Add a fallback [`Service`] to the router.
280    ///
281    /// For details see [`axum::Router::fallback_service`].
282    ///
283    /// # Note
284    ///
285    /// This method doesn't add anything to OpenaAPI spec.
286    pub fn fallback_service<Svc>(self, svc: Svc) -> Self
287    where
288        Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
289        Svc::Response: IntoResponse,
290        Svc::Future: Send + 'static,
291    {
292        Router {
293            axum_router: self.axum_router.fallback_service(svc),
294            ..self
295        }
296    }
297
298    /// Provide the state for the router.
299    ///
300    /// For details see [`axum::Router::with_state`].
301    pub fn with_state<S2>(self, state: S) -> Router<S2> {
302        Router {
303            axum_router: self.axum_router.with_state(state),
304            routes_operations_map: self.routes_operations_map,
305            openapi_builder_template: self.openapi_builder_template,
306        }
307    }
308
309    /// Separate router into [`axum::Router`] and list of operations.
310    pub fn into_parts(self) -> (AxumRouter<S>, RoutesOperations) {
311        (
312            self.axum_router,
313            RoutesOperations::new(self.routes_operations_map),
314        )
315    }
316
317    /// Get inner [`axum::Router`].
318    pub fn axum_router(&self) -> AxumRouter<S> {
319        self.axum_router.clone()
320    }
321
322    /// Get list of operations.
323    pub fn routes_operations(&self) -> RoutesOperations {
324        RoutesOperations::new(self.routes_operations_map.clone())
325    }
326
327    /// Generate [`OpenApiBuilder`] from current router.
328    ///
329    /// Generated builder will be based on current builder template,
330    /// have all routes and types, present in this router.
331    ///
332    /// If template was not set, then [`OpenApiBuilder::default()`] is used.
333    pub fn generate_openapi_builder(&self) -> OpenApiBuilder {
334        let routes = self.routes_operations().openapi_operation_generators();
335        let mut builder = self.openapi_builder_template.clone();
336        // Don't use try_operations since duplicates should be checked
337        // when mounting route to axum router.
338        builder.operations(
339            routes
340                .into_iter()
341                .map(|((x, y), z)| (convert_axum_path_to_openapi(&x), y, z)),
342        );
343        builder
344    }
345
346    /// Set [`OpenApiBuilder`] template for this router.
347    ///
348    /// By default [`OpenApiBuilder::default()`] is used.
349    pub fn set_openapi_builder_template(&mut self, builder: OpenApiBuilder) -> &mut Self {
350        self.openapi_builder_template = builder;
351        self
352    }
353
354    /// Update [`OpenApiBuilder`] template of this router.
355    ///
356    /// By default [`OpenApiBuilder::default()`] is used.
357    pub fn update_openapi_builder_template<F>(&mut self, f: F) -> &mut Self
358    where
359        F: FnOnce(&mut OpenApiBuilder),
360    {
361        f(&mut self.openapi_builder_template);
362        self
363    }
364
365    /// Get mutable reference to [`OpenApiBuilder`] template of this router.
366    ///
367    /// By default [`OpenApiBuilder::default()`] is set.
368    pub fn openapi_builder_template_mut(&mut self) -> &mut OpenApiBuilder {
369        &mut self.openapi_builder_template
370    }
371
372    /// Generate OpenAPI specification, mount it to inner router and return inner [`axum::Router`].
373    ///
374    /// Specification is based on [`OpenApiBuilder`] template, if one was set previously.
375    /// If template was not set, then [`OpenApiBuilder::default()`] is used.
376    ///
377    /// Note that passed `title` and `version` will override same values in OpenAPI builder template.
378    ///
379    /// By default specification served at [`DEFAULT_OPENAPI_PATH`] (`/openapi`).
380    ///
381    /// # Example
382    ///
383    /// ```rust
384    /// # use okapi_operation::{*, axum_integration::*};
385    /// #[openapi]
386    /// async fn handler() {}
387    ///
388    /// let app = Router::new().route("/", get(openapi_handler!(handler)));
389    /// # async {
390    /// let app = app.finish_openapi("/openapi", "Demo", "1.0.0").expect("ok");
391    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
392    /// # axum::serve(listener, app.into_make_service()).await.unwrap()
393    /// # };
394    /// ```
395    pub fn finish_openapi<'a>(
396        mut self,
397        serve_path: impl Into<Option<&'a str>>,
398        title: impl Into<String>,
399        version: impl Into<String>,
400    ) -> Result<AxumRouter<S>, anyhow::Error> {
401        let serve_path = serve_path.into().unwrap_or(DEFAULT_OPENAPI_PATH);
402
403        // Don't use try_operation since duplicates should be checked
404        // when mounting route to axum router.
405        let spec = self
406            .generate_openapi_builder()
407            .operation(
408                convert_axum_path_to_openapi(serve_path),
409                Method::GET,
410                super::serve_openapi_spec__openapi,
411            )
412            .title(title)
413            .version(version)
414            .build()?;
415
416        self = self.route(serve_path, get(super::serve_openapi_spec).with_state(spec));
417
418        Ok(self.axum_router)
419    }
420}
421
#[cfg(test)]
mod tests {
    // The `let _ = async move { ... }` futures below are intentionally never
    // polled: they only check that the resulting axum router can be served.
    #![allow(clippy::let_underscore_future)]

    use axum::{http::Method, routing::get as axum_get};
    use okapi::openapi3::Operation;
    use tokio::net::TcpListener;

    use super::*;
    use crate::{
        Components,
        axum_integration::{HandlerExt, get, post},
    };

    /// Placeholder generator. These tests only inspect which operations are
    /// registered and never build a spec, so it is not invoked.
    fn openapi_generator(_: &mut Components) -> Result<Operation, anyhow::Error> {
        unimplemented!()
    }

    #[test]
    fn mount_axum_types() {
        let axum_router = AxumRouter::new().route("/get", axum_get(|| async {}));
        let (app, meta) = Router::new()
            .route("/", axum_get(|| async {}))
            .nest("/nested", axum_router.clone())
            .merge(axum_router)
            .into_parts();
        assert!(meta.0.is_empty());
        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }

    #[test]
    fn mount() {
        let router = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let router2 = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let (app, ops) = Router::new()
            .route("/", get(|| async {}))
            .nest("/nested", router)
            .merge(router2)
            .route(
                "/my_path",
                get((|| async {}).with_openapi(openapi_generator)),
            )
            .route(
                "/my_path",
                post((|| async {}).with_openapi(openapi_generator)),
            )
            .into_parts();

        // Handlers without OpenAPI generators don't produce operations.
        assert!(ops.get_path("/").is_none());
        assert!(ops.get_path("/get").is_none());
        assert!(ops.get_path("/nested/get").is_none());

        assert!(ops.get_path("/get_with_spec").is_some());
        assert!(ops.get("/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/get_with_spec", &Method::POST).is_none());
        assert!(ops.get_path("/nested/get_with_spec").is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::POST).is_none());

        // Two `route` calls on the same path accumulate their operations.
        assert!(ops.get("/my_path", &Method::GET).is_some());
        assert!(ops.get("/my_path", &Method::POST).is_some());

        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }
}