1use std::{borrow::Cow, mem};
2
3use axum::routing::{MethodFilter, MethodRouter};
4
5use crate::{
6 Registry, adapters::AxumHandlerAdapter, endpoint::Endpoint, openapi::OpenApi, rapidoc,
7};
8
9pub struct Router<S = ()> {
11 axum: axum::Router<S>,
12 registry: Registry,
13 serve_spec_at: Vec<String>,
14 #[expect(clippy::type_complexity)]
15 modify_openapi: Vec<Box<dyn FnOnce(&mut Registry, &S) + Send + Sync + 'static>>,
16}
17
18impl<S> Router<S>
19where
20 S: Clone + Send + Sync + 'static,
21{
22 pub fn endpoint<E: Endpoint<S, V>, V: 'static>(mut self, endpoint: E) -> Self {
23 self.endpoint_with_base("", endpoint);
24 self
25 }
26 fn endpoint_with_base<E: Endpoint<S, V>, V: 'static>(&mut self, base: &str, endpoint: E) {
27 let path = make_path(base, endpoint.path().as_ref());
28 let method = endpoint.method();
29 let filter = MethodFilter::try_from(endpoint.method())
30 .expect("a matching method filter should exist");
31 let operation = endpoint.openapi(&mut self.registry);
32 self.modify_axum(|router| {
33 router.route(
34 path.as_ref(),
35 MethodRouter::default().on(filter, AxumHandlerAdapter(endpoint)),
36 )
37 });
38 let operation_entry = self
39 .registry
40 .openapi_mut()
41 .paths
42 .paths
43 .entry(path.clone())
44 .or_default()
45 .operation_by_method_mut(method.clone())
46 .expect("a matching operation entry should exist in PathItem");
47 if operation_entry.is_some() {
48 panic!("colliding operations for path {path:?} and method {method}");
49 }
50 *operation_entry = Some(operation);
51 }
52 pub fn nest<'r>(&'r mut self, base: &'r str) -> NestedRouter<'r, S> {
53 NestedRouter {
54 router: self,
55 base: base.into(),
56 }
57 }
58 pub fn with_state(self, state: S) -> Router {
59 let cloned_state = state.clone();
60 Router {
61 axum: self.axum.with_state(state),
62 serve_spec_at: self.serve_spec_at,
63 registry: self.registry,
64 modify_openapi: vec![Box::new(move |openapi, _| {
65 for modifier in self.modify_openapi {
66 modifier(openapi, &cloned_state);
67 }
68 })],
69 }
70 }
71 pub fn modify_axum(&mut self, modifier: impl FnOnce(axum::Router<S>) -> axum::Router<S>) {
72 self.axum = modifier(mem::take(&mut self.axum));
73 }
74 pub fn registry_mut(&mut self) -> &mut Registry {
75 &mut self.registry
76 }
77 pub fn modify_openapi(
78 &mut self,
79 modifier: impl FnOnce(&mut Registry, &S) + Send + Sync + 'static,
80 ) {
81 self.modify_openapi.push(Box::new(modifier));
82 }
83 pub fn serve_docs(mut self, path: &str) -> Self {
84 let serve_spec_at = make_path(path, "openapi.json");
85 self.modify_axum(|router| {
86 router.route(
87 path,
88 axum::routing::get(rapidoc::RapiDoc {
89 spec_url: serve_spec_at.clone(),
90 }),
91 )
92 });
93 self.serve_spec_at.push(serve_spec_at);
94 self
95 }
96}
97
98impl Router {
99 pub fn into_parts(mut self) -> (axum::Router, OpenApi) {
100 for modifier in self.modify_openapi {
101 modifier(&mut self.registry, &());
102 }
103 let openapi = self.registry.into_openapi();
104
105 let mut axum = self.axum;
106 for spec_at in self.serve_spec_at {
107 axum = axum.route(&spec_at, axum::routing::get(axum::Json(openapi.clone())));
108 }
109 (axum, openapi)
110 }
111
112 #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
113 pub fn serve<L>(self, listener: L) -> axum::serve::Serve<L, axum::Router, axum::Router>
114 where
115 L: axum::serve::Listener,
116 {
117 axum::serve(listener, self.into_parts().0)
118 }
119}
120
121impl<S: Clone + Send + Sync + 'static> Default for Router<S> {
122 fn default() -> Self {
123 Self {
124 axum: Default::default(),
125 registry: Default::default(),
126 serve_spec_at: Default::default(),
127 modify_openapi: Default::default(),
128 }
129 }
130}
131
132pub struct NestedRouter<'r, S = ()> {
134 router: &'r mut Router<S>,
135 base: Cow<'r, str>,
136}
137
138impl<'r, S> NestedRouter<'r, S>
139where
140 S: Clone + Send + Sync + 'static,
141{
142 pub fn endpoint<E: Endpoint<S, V>, V: 'static>(self, endpoint: E) -> Self {
143 self.router.endpoint_with_base(&self.base, endpoint);
144 self
145 }
146 pub fn nest(&mut self, base: &str) -> NestedRouter<'_, S> {
147 NestedRouter {
148 router: self.router,
149 base: make_path(&self.base, base).into(),
150 }
151 }
152 pub fn into_nested(self, base: &str) -> NestedRouter<'r, S> {
153 NestedRouter {
154 router: self.router,
155 base: make_path(&self.base, base).into(),
156 }
157 }
158}
159
/// Joins a base prefix and a route path into one absolute path.
///
/// Slashes are stripped from both ends of `base` and from the start of
/// `path`, and the remaining pieces are joined with a single `/`. The result
/// always starts with `/`. Trailing slashes on `path` are kept.
///
/// A `base` made up only of slashes (e.g. `"/"`) is treated like an empty
/// base. Previously it produced `"//path"`; for `serve_docs("/")` that gave a
/// spec URL of `//openapi.json`, which browsers resolve as a protocol-relative
/// URL to a host named `openapi.json`.
fn make_path(base: &str, path: &str) -> String {
    let path = path.trim_start_matches('/');
    match base.trim_matches('/') {
        "" => format!("/{path}"),
        base => format!("/{base}/{path}"),
    }
}