Skip to main content

okapi_operation/axum_integration/
router.rs

1use std::{collections::HashMap, convert::Infallible, fmt};
2
3use axum::{
4    Router as AxumRouter, extract::Request, handler::Handler, http::Method, response::IntoResponse,
5    routing::Route,
6};
7use tower::{Layer, Service};
8
9use super::{
10    get,
11    method_router::{MethodRouter, MethodRouterOperations},
12    operations::RoutesOperations,
13};
14use crate::OpenApiBuilder;
15
/// Path at which [`Router::finish_openapi`] serves the OpenAPI specification
/// when the caller does not provide an explicit serve path.
pub const DEFAULT_OPENAPI_PATH: &str = "/openapi";
17
/// Drop-in replacement for [`axum::Router`], which supports OpenAPI operations.
///
/// This replacement cannot be used as a [`Service`] directly; instead it
/// requires an explicit conversion of this type to `axum::Router`. This is
/// done to ensure that the OpenAPI specification is generated and mounted.
pub struct Router<S = ()> {
    // Wrapped axum router; every routing call is forwarded to it.
    axum_router: AxumRouter<S>,
    // OpenAPI operations collected so far, keyed by route path.
    routes_operations_map: HashMap<String, MethodRouterOperations>,
    // Builder that is cloned as the starting point when generating a specification.
    openapi_builder_template: OpenApiBuilder,
}
28
29impl<S> From<AxumRouter<S>> for Router<S> {
30    fn from(value: AxumRouter<S>) -> Self {
31        Self {
32            axum_router: value,
33            routes_operations_map: Default::default(),
34            openapi_builder_template: OpenApiBuilder::default(),
35        }
36    }
37}
38
39impl<S> Clone for Router<S>
40where
41    S: Clone + Send + Sync + 'static,
42{
43    fn clone(&self) -> Self {
44        Self {
45            axum_router: self.axum_router.clone(),
46            routes_operations_map: self.routes_operations_map.clone(),
47            openapi_builder_template: self.openapi_builder_template.clone(),
48        }
49    }
50}
51
52impl<S> Default for Router<S>
53where
54    S: Clone + Send + Sync + 'static,
55{
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl<S> fmt::Debug for Router<S>
62where
63    S: fmt::Debug,
64{
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        self.axum_router.fmt(f)
67    }
68}
69
70impl<S> Router<S>
71where
72    S: Clone + Send + Sync + 'static,
73{
74    /// Create new router.
75    pub fn new() -> Self {
76        Self {
77            axum_router: AxumRouter::new(),
78            routes_operations_map: HashMap::new(),
79            openapi_builder_template: OpenApiBuilder::default(),
80        }
81    }
82
83    /// Add another route to the router.
84    ///
85    /// This method works for both [`MethodRouter`] and one from axum.
86    ///
87    /// For details see [`axum::Router::route`].
88    ///
89    /// # Example
90    ///
91    /// ```rust
92    /// # use okapi_operation::{*, axum_integration::*};
93    /// #[openapi]
94    /// async fn handler() {}
95    ///
96    /// let app = Router::new().route("/", get(openapi_handler!(handler)));
97    /// # async {
98    /// # let (app, _) = app.into_parts();
99    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
100    /// # axum::serve(listener, app.into_make_service()).await.unwrap()
101    /// # };
102    /// ```
103    pub fn route<R>(mut self, path: &str, method_router: R) -> Self
104    where
105        R: Into<MethodRouter<S>>,
106    {
107        let method_router = method_router.into();
108
109        // Merge operations
110        let s = self.routes_operations_map.entry(path.into()).or_default();
111        *s = s.clone().merge(method_router.operations);
112
113        Self {
114            axum_router: self
115                .axum_router
116                .route(path, method_router.axum_method_router),
117            ..self
118        }
119    }
120
121    /// Add another route to the router that calls a [`Service`].
122    ///
123    /// For details see [`axum::Router::route_service`].
124    ///
125    /// # Example
126    ///
127    /// TODO
128    pub fn route_service<Svc>(self, path: &str, service: Svc) -> Self
129    where
130        Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
131        Svc::Response: IntoResponse,
132        Svc::Future: Send + 'static,
133    {
134        Self {
135            axum_router: self.axum_router.route_service(path, service),
136            ..self
137        }
138    }
139
140    /// Nest a router at some path.
141    ///
142    /// This method works for both [`Router`] and one from axum.
143    ///
144    /// For details see [`axum::Router::nest`].
145    ///
146    /// # Example
147    ///
148    /// ```rust
149    /// # use okapi_operation::{*, axum_integration::*};
150    /// #[openapi]
151    /// async fn handler() {}
152    /// let handler_router = Router::new().route("/", get(openapi_handler!(handler)));
153    /// let app = Router::new().nest("/handle", handler_router);
154    /// # async {
155    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
156    /// # axum::serve(listener, app.into_parts().0.into_make_service()).await.unwrap()
157    /// # };
158    /// ```
159    pub fn nest<R>(mut self, path: &str, router: R) -> Self
160    where
161        R: Into<Router<S>>,
162    {
163        let router = router.into();
164        for (inner_path, operation) in router.routes_operations_map.into_iter() {
165            let _ = self
166                .routes_operations_map
167                .insert(format!("{}{}", path, inner_path), operation);
168        }
169        Self {
170            axum_router: self.axum_router.nest(path, router.axum_router),
171            ..self
172        }
173    }
174
175    /// Like `nest`, but accepts an arbitrary [`Service`].
176    ///
177    /// For details see [`axum::Router::nest_service`].
178    pub fn nest_service<Svc>(self, path: &str, svc: Svc) -> Self
179    where
180        Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
181        Svc::Response: IntoResponse,
182        Svc::Future: Send + 'static,
183    {
184        Self {
185            axum_router: self.axum_router.nest_service(path, svc),
186            ..self
187        }
188    }
189
190    /// Merge two routers into one.
191    ///
192    /// This method works for both [`Router`] and one from axum.
193    ///
194    /// For details see [`axum::Router::merge`].
195    ///
196    /// # Example
197    ///
198    /// ```rust
199    /// # use okapi_operation::{*, axum_integration::*};
200    /// #[openapi]
201    /// async fn handler() {}
202    /// let handler_router = Router::new().route("/another_handler", get(openapi_handler!(handler)));
203    /// let app = Router::new().route("/", get(openapi_handler!(handler))).merge(handler_router);
204    /// # async {
205    /// # let (app, _) = app.into_parts();
206    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
207    /// # axum::serve(listener, app.into_make_service()).await.unwrap()
208    /// # };
209    /// ```
210    pub fn merge<R>(mut self, other: R) -> Self
211    where
212        R: Into<Router<S>>,
213    {
214        let other = other.into();
215        self.routes_operations_map
216            .extend(other.routes_operations_map);
217        Self {
218            axum_router: self.axum_router.merge(other.axum_router),
219            ..self
220        }
221    }
222
223    /// Apply a [`tower::Layer`] to the router.
224    ///
225    /// For details see [`axum::Router::layer`].
226    pub fn layer<L>(self, layer: L) -> Router<S>
227    where
228        L: Layer<Route> + Clone + Send + Sync + 'static,
229        L::Service: Service<Request> + Clone + Send + Sync + 'static,
230        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
231        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
232        <L::Service as Service<Request>>::Future: Send + 'static,
233    {
234        Router {
235            axum_router: self.axum_router.layer(layer),
236            routes_operations_map: self.routes_operations_map,
237            openapi_builder_template: self.openapi_builder_template,
238        }
239    }
240
241    /// Apply a [`tower::Layer`] to the router that will only run if the request matches a route.
242    ///
243    /// For details see [`axum::Router::route_layer`].
244    pub fn route_layer<L>(self, layer: L) -> Self
245    where
246        L: Layer<Route> + Clone + Send + Sync + 'static,
247        L::Service: Service<Request> + Clone + Send + Sync + 'static,
248        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
249        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
250        <L::Service as Service<Request>>::Future: Send + 'static,
251    {
252        Router {
253            axum_router: self.axum_router.route_layer(layer),
254            routes_operations_map: self.routes_operations_map,
255            openapi_builder_template: self.openapi_builder_template,
256        }
257    }
258
259    // TODO: somehow mount openapi doc from this handler
260    /// Add a fallback [`Service`] to the router.
261    ///
262    /// For details see [`axum::Router::fallback_service`].
263    ///
264    /// # Note
265    ///
266    /// This method doesn't add anything to OpenaAPI spec.
267    pub fn fallback<H, T>(self, handler: H) -> Self
268    where
269        H: Handler<T, S>,
270        T: 'static,
271    {
272        Router {
273            axum_router: self.axum_router.fallback(handler),
274            ..self
275        }
276    }
277
278    /// Add a fallback [`Service`] to the router.
279    ///
280    /// For details see [`axum::Router::fallback_service`].
281    ///
282    /// # Note
283    ///
284    /// This method doesn't add anything to OpenaAPI spec.
285    pub fn fallback_service<Svc>(self, svc: Svc) -> Self
286    where
287        Svc: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
288        Svc::Response: IntoResponse,
289        Svc::Future: Send + 'static,
290    {
291        Router {
292            axum_router: self.axum_router.fallback_service(svc),
293            ..self
294        }
295    }
296
297    /// Provide the state for the router.
298    ///
299    /// For details see [`axum::Router::with_state`].
300    pub fn with_state<S2>(self, state: S) -> Router<S2> {
301        Router {
302            axum_router: self.axum_router.with_state(state),
303            routes_operations_map: self.routes_operations_map,
304            openapi_builder_template: self.openapi_builder_template,
305        }
306    }
307
308    /// Separate router into [`axum::Router`] and list of operations.
309    pub fn into_parts(self) -> (AxumRouter<S>, RoutesOperations) {
310        (
311            self.axum_router,
312            RoutesOperations::new(self.routes_operations_map),
313        )
314    }
315
316    /// Get inner [`axum::Router`].
317    pub fn axum_router(&self) -> AxumRouter<S> {
318        self.axum_router.clone()
319    }
320
321    /// Get list of operations.
322    pub fn routes_operations(&self) -> RoutesOperations {
323        RoutesOperations::new(self.routes_operations_map.clone())
324    }
325
326    /// Generate [`OpenApiBuilder`] from current router.
327    ///
328    /// Generated builder will be based on current builder template,
329    /// have all routes and types, present in this router.
330    ///
331    /// If template was not set, then [`OpenApiBuilder::default()`] is used.
332    pub fn generate_openapi_builder(&self) -> OpenApiBuilder {
333        let routes = self.routes_operations().openapi_operation_generators();
334        let mut builder = self.openapi_builder_template.clone();
335        // Don't use try_operations since duplicates should be checked
336        // when mounting route to axum router.
337        builder.operations(routes.into_iter().map(|((x, y), z)| (x, y, z)));
338        builder
339    }
340
341    /// Set [`OpenApiBuilder`] template for this router.
342    ///
343    /// By default [`OpenApiBuilder::default()`] is used.
344    pub fn set_openapi_builder_template(&mut self, builder: OpenApiBuilder) -> &mut Self {
345        self.openapi_builder_template = builder;
346        self
347    }
348
349    /// Update [`OpenApiBuilder`] template of this router.
350    ///
351    /// By default [`OpenApiBuilder::default()`] is used.
352    pub fn update_openapi_builder_template<F>(&mut self, f: F) -> &mut Self
353    where
354        F: FnOnce(&mut OpenApiBuilder),
355    {
356        f(&mut self.openapi_builder_template);
357        self
358    }
359
360    /// Get mutable reference to [`OpenApiBuilder`] template of this router.
361    ///
362    /// By default [`OpenApiBuilder::default()`] is set.
363    pub fn openapi_builder_template_mut(&mut self) -> &mut OpenApiBuilder {
364        &mut self.openapi_builder_template
365    }
366
367    /// Generate OpenAPI specification, mount it to inner router and return inner [`axum::Router`].
368    ///
369    /// Specification is based on [`OpenApiBuilder`] template, if one was set previously.
370    /// If template was not set, then [`OpenApiBuilder::default()`] is used.
371    ///
372    /// Note that passed `title` and `version` will override same values in OpenAPI builder template.
373    ///
374    /// By default specification served at [`DEFAULT_OPENAPI_PATH`] (`/openapi`).
375    ///
376    /// # Example
377    ///
378    /// ```rust
379    /// # use okapi_operation::{*, axum_integration::*};
380    /// #[openapi]
381    /// async fn handler() {}
382    ///
383    /// let app = Router::new().route("/", get(openapi_handler!(handler)));
384    /// # async {
385    /// let app = app.finish_openapi("/openapi", "Demo", "1.0.0").expect("ok");
386    /// # let listener = tokio::net::TcpListener::bind("").await.unwrap();
387    /// # axum::serve(listener, app.into_make_service()).await.unwrap()
388    /// # };
389    /// ```
390    pub fn finish_openapi<'a>(
391        mut self,
392        serve_path: impl Into<Option<&'a str>>,
393        title: impl Into<String>,
394        version: impl Into<String>,
395    ) -> Result<AxumRouter<S>, anyhow::Error> {
396        let serve_path = serve_path.into().unwrap_or(DEFAULT_OPENAPI_PATH);
397
398        // Don't use try_operation since duplicates should be checked
399        // when mounting route to axum router.
400        let spec = self
401            .generate_openapi_builder()
402            .operation(serve_path, Method::GET, super::serve_openapi_spec__openapi)
403            .title(title)
404            .version(version)
405            .build()?;
406
407        self = self.route(serve_path, get(super::serve_openapi_spec).with_state(spec));
408
409        Ok(self.axum_router)
410    }
411}
412
#[cfg(test)]
mod tests {
    // The server futures below are built only to type-check the resulting
    // routers; they are intentionally never awaited.
    #![allow(clippy::let_underscore_future)]

    use axum::{http::Method, routing::get as axum_get};
    use okapi::openapi3::Operation;
    use tokio::net::TcpListener;

    use super::*;
    use crate::{
        BuilderOptions, Components,
        axum_integration::{HandlerExt, get, post},
    };

    /// Placeholder operation generator; the tests only register it and never
    /// generate a specification, so it is not expected to be called.
    fn openapi_generator(
        _: &mut Components,
        _: &BuilderOptions,
    ) -> Result<Operation, anyhow::Error> {
        unimplemented!()
    }

    /// Plain axum routers and method routers can be mounted and contribute no
    /// OpenAPI operations.
    #[test]
    fn mount_axum_types() {
        let axum_router = AxumRouter::new().route("/get", axum_get(|| async {}));
        let (app, meta) = Router::new()
            .route("/", axum_get(|| async {}))
            .nest("/nested", axum_router.clone())
            .merge(axum_router)
            .into_parts();
        assert!(meta.0.is_empty());
        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }

    /// Operations are recorded only for handlers with an attached generator,
    /// under the correct (possibly nested) path and method.
    #[test]
    fn mount() {
        let router = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let router2 = Router::new().route("/get", get(|| async {})).route(
            "/get_with_spec",
            get((|| async {}).with_openapi(openapi_generator)),
        );
        let (app, ops) = Router::new()
            .route("/", get(|| async {}))
            .nest("/nested", router)
            .merge(router2)
            .route(
                "/my_path",
                get((|| async {}).with_openapi(openapi_generator)),
            )
            .route(
                "/my_path",
                post((|| async {}).with_openapi(openapi_generator)),
            )
            .into_parts();

        // Routes without a generator are not present.
        assert!(ops.get_path("/").is_none());
        assert!(ops.get_path("/get").is_none());
        assert!(ops.get_path("/nested/get").is_none());

        // Merged and nested routes with a generator are present for GET only.
        assert!(ops.get_path("/get_with_spec").is_some());
        assert!(ops.get("/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/get_with_spec", &Method::POST).is_none());
        assert!(ops.get_path("/nested/get_with_spec").is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::GET).is_some());
        assert!(ops.get("/nested/get_with_spec", &Method::POST).is_none());

        // Two `route` calls on the same path accumulate their methods.
        assert!(ops.get("/my_path", &Method::GET).is_some());
        assert!(ops.get("/my_path", &Method::POST).is_some());

        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }
}