// autapi/router.rs

1use std::{borrow::Cow, mem};
2
3use axum::routing::{MethodFilter, MethodRouter};
4
5use crate::{
6    Registry, adapters::AxumHandlerAdapter, endpoint::Endpoint, openapi::OpenApi, rapidoc,
7};
8
9/// Wrapper around axum's `Router`, allowing registration of endpoints.
10pub struct Router<S = ()> {
11    axum: axum::Router<S>,
12    registry: Registry,
13    serve_spec_at: Vec<String>,
14    #[expect(clippy::type_complexity)]
15    modify_openapi: Vec<Box<dyn FnOnce(&mut Registry, &S) + Send + Sync + 'static>>,
16}
17
18impl<S> Router<S>
19where
20    S: Clone + Send + Sync + 'static,
21{
22    pub fn endpoint<E: Endpoint<S, V>, V: 'static>(mut self, endpoint: E) -> Self {
23        self.endpoint_with_base("", endpoint);
24        self
25    }
26    fn endpoint_with_base<E: Endpoint<S, V>, V: 'static>(&mut self, base: &str, endpoint: E) {
27        let path = make_path(base, endpoint.path().as_ref());
28        let method = endpoint.method();
29        let filter = MethodFilter::try_from(endpoint.method())
30            .expect("a matching method filter should exist");
31        let operation = endpoint.openapi(&mut self.registry);
32        self.modify_axum(|router| {
33            router.route(
34                path.as_ref(),
35                MethodRouter::default().on(filter, AxumHandlerAdapter(endpoint)),
36            )
37        });
38        let operation_entry = self
39            .registry
40            .openapi_mut()
41            .paths
42            .paths
43            .entry(path.clone())
44            .or_default()
45            .operation_by_method_mut(method.clone())
46            .expect("a matching operation entry should exist in PathItem");
47        if operation_entry.is_some() {
48            panic!("colliding operations for path {path:?} and method {method}");
49        }
50        *operation_entry = Some(operation);
51    }
52    pub fn nest<'r>(&'r mut self, base: &'r str) -> NestedRouter<'r, S> {
53        NestedRouter {
54            router: self,
55            base: base.into(),
56        }
57    }
58    pub fn with_state(self, state: S) -> Router {
59        let cloned_state = state.clone();
60        Router {
61            axum: self.axum.with_state(state),
62            serve_spec_at: self.serve_spec_at,
63            registry: self.registry,
64            modify_openapi: vec![Box::new(move |openapi, _| {
65                for modifier in self.modify_openapi {
66                    modifier(openapi, &cloned_state);
67                }
68            })],
69        }
70    }
71    pub fn modify_axum(&mut self, modifier: impl FnOnce(axum::Router<S>) -> axum::Router<S>) {
72        self.axum = modifier(mem::take(&mut self.axum));
73    }
74    pub fn registry_mut(&mut self) -> &mut Registry {
75        &mut self.registry
76    }
77    pub fn modify_openapi(
78        &mut self,
79        modifier: impl FnOnce(&mut Registry, &S) + Send + Sync + 'static,
80    ) {
81        self.modify_openapi.push(Box::new(modifier));
82    }
83    pub fn serve_docs(mut self, path: &str) -> Self {
84        let serve_spec_at = make_path(path, "openapi.json");
85        self.modify_axum(|router| {
86            router.route(
87                path,
88                axum::routing::get(rapidoc::RapiDoc {
89                    spec_url: serve_spec_at.clone(),
90                }),
91            )
92        });
93        self.serve_spec_at.push(serve_spec_at);
94        self
95    }
96}
97
98impl Router {
99    pub fn into_parts(mut self) -> (axum::Router, OpenApi) {
100        for modifier in self.modify_openapi {
101            modifier(&mut self.registry, &());
102        }
103        let openapi = self.registry.into_openapi();
104
105        let mut axum = self.axum;
106        for spec_at in self.serve_spec_at {
107            axum = axum.route(&spec_at, axum::routing::get(axum::Json(openapi.clone())));
108        }
109        (axum, openapi)
110    }
111
112    #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
113    pub fn serve<L>(self, listener: L) -> axum::serve::Serve<L, axum::Router, axum::Router>
114    where
115        L: axum::serve::Listener,
116    {
117        axum::serve(listener, self.into_parts().0)
118    }
119}
120
121impl<S: Clone + Send + Sync + 'static> Default for Router<S> {
122    fn default() -> Self {
123        Self {
124            axum: Default::default(),
125            registry: Default::default(),
126            serve_spec_at: Default::default(),
127            modify_openapi: Default::default(),
128        }
129    }
130}
131
132/// Configure endpoints for a router starting with a path.
133pub struct NestedRouter<'r, S = ()> {
134    router: &'r mut Router<S>,
135    base: Cow<'r, str>,
136}
137
138impl<'r, S> NestedRouter<'r, S>
139where
140    S: Clone + Send + Sync + 'static,
141{
142    pub fn endpoint<E: Endpoint<S, V>, V: 'static>(self, endpoint: E) -> Self {
143        self.router.endpoint_with_base(&self.base, endpoint);
144        self
145    }
146    pub fn nest(&mut self, base: &str) -> NestedRouter<'_, S> {
147        NestedRouter {
148            router: self.router,
149            base: make_path(&self.base, base).into(),
150        }
151    }
152    pub fn into_nested(self, base: &str) -> NestedRouter<'r, S> {
153        NestedRouter {
154            router: self.router,
155            base: make_path(&self.base, base).into(),
156        }
157    }
158}
159
/// Joins `base` and `path` into a single absolute path with one leading slash.
///
/// Slashes surrounding `base` and leading slashes of `path` are dropped before the
/// two are joined, so the segments are separated by exactly one `/`. A `base` that
/// is empty or consists only of slashes contributes no segment; trimming happens
/// before that check so that a base of `"/"` does not produce a `//`-prefixed
/// result. Trailing slashes of `path` are kept as they are.
fn make_path(base: &str, path: &str) -> String {
    let base = base.trim_matches('/');
    let path = path.trim_start_matches('/');
    if base.is_empty() {
        format!("/{path}")
    } else {
        format!("/{base}/{path}")
    }
}