1use std::convert::Infallible;
10
11use axum::extract::Request;
12use axum::response::IntoResponse;
13use axum::routing::Route;
14use tower::{Layer, Service};
15use utoipa_axum::router::OpenApiRouter;
16
17use crate::contribution::{apply_contribution, DocumentedLayer, LayerContribution};
18
19pub trait OpenApiRouterExt<S>: Sized {
23 fn layer_documented<L>(self, layer: L) -> Self
37 where
38 L: Layer<Route> + DocumentedLayer + Clone + Send + Sync + 'static,
39 L::Service: Service<Request> + Clone + Send + Sync + 'static,
40 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
41 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
42 <L::Service as Service<Request>>::Future: Send + 'static;
43
44 fn tag_all(self, tag: impl Into<String>) -> Self;
63}
64
65impl<S: Clone + Send + Sync + 'static> OpenApiRouterExt<S> for OpenApiRouter<S> {
66 fn layer_documented<L>(mut self, layer: L) -> Self
67 where
68 L: Layer<Route> + DocumentedLayer + Clone + Send + Sync + 'static,
69 L::Service: Service<Request> + Clone + Send + Sync + 'static,
70 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
71 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
72 <L::Service as Service<Request>>::Future: Send + 'static,
73 {
74 let contribution = layer.contribution();
75 if !contribution.is_empty() {
76 apply_contribution(self.get_openapi_mut(), &contribution);
77 }
78 self.layer(layer)
79 }
80
81 fn tag_all(mut self, tag: impl Into<String>) -> Self {
82 let contribution = LayerContribution::new().with_tag(tag);
83 apply_contribution(self.get_openapi_mut(), &contribution);
84 self
85 }
86}
87
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    use axum::body::Body;
    use axum::http::Response as HttpResponse;
    use utoipa::openapi::path::{HttpMethod, Operation, OperationBuilder, PathItem};
    use utoipa::openapi::{OpenApiBuilder, PathsBuilder};

    // `super::*` already provides `Infallible`, `Request`, `Layer`, `Service`,
    // `OpenApiRouter`, `DocumentedLayer`, `LayerContribution` and
    // `OpenApiRouterExt`, so they are not re-imported here.
    use super::*;
    use crate::headers::HeaderParam;

    /// Test layer that documents one header, built with
    /// `HeaderParam::required(header_name)`.
    ///
    /// At runtime it wraps the inner service in `MockDocService`, which
    /// forwards every call unchanged.
    #[derive(Clone)]
    struct MockDocLayer {
        header_name: &'static str,
    }

    impl DocumentedLayer for MockDocLayer {
        fn contribution(&self) -> LayerContribution {
            LayerContribution::new().with_header(HeaderParam::required(self.header_name))
        }
    }

    impl<Inner> Layer<Inner> for MockDocLayer {
        type Service = MockDocService<Inner>;

        fn layer(&self, inner: Inner) -> Self::Service {
            MockDocService { inner }
        }
    }

    /// Pass-through service: `poll_ready` and `call` delegate to `inner`.
    #[derive(Clone)]
    struct MockDocService<Inner> {
        inner: Inner,
    }

    impl<Inner> Service<Request> for MockDocService<Inner>
    where
        Inner: Service<Request, Response = HttpResponse<Body>, Error = Infallible>
            + Clone
            + Send
            + 'static,
        Inner::Future: Send + 'static,
    {
        type Response = HttpResponse<Body>;
        type Error = Infallible;
        type Future = Inner::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: Request) -> Self::Future {
            self.inner.call(req)
        }
    }

    /// Builds a router whose OpenAPI document holds one default GET operation
    /// at `path`.
    ///
    /// Only the document is populated; no handler is attached here.
    fn router_with_path(path: &str) -> OpenApiRouter {
        let item = PathItem::new(HttpMethod::Get, OperationBuilder::new().build());
        let paths = PathsBuilder::new().path(path, item).build();
        let openapi = OpenApiBuilder::new().paths(paths).build();
        OpenApiRouter::with_openapi(openapi)
    }

    /// Returns a clone of the GET operation documented at `path`.
    ///
    /// # Panics
    /// Panics if `path`, or its GET operation, is missing from the router's
    /// document.
    fn op_for(router: &OpenApiRouter, path: &str) -> Operation {
        router
            .get_openapi()
            .paths
            .paths
            .get(path)
            .expect("path present")
            .get
            .as_ref()
            .expect("get operation present")
            .clone()
    }

    /// Names of every parameter on `op`, whatever its location. Returns an
    /// empty vector when the operation has no parameters.
    fn param_names(op: &Operation) -> Vec<String> {
        op.parameters
            .as_ref()
            .map(|params| params.iter().map(|p| p.name.clone()).collect())
            .unwrap_or_default()
    }

    /// Tags on `op`, or an empty vector when the operation has none.
    fn op_tags(op: &Operation) -> Vec<String> {
        op.tags.clone().unwrap_or_default()
    }

    #[test]
    fn layer_documented_stamps_contribution_on_current_operations() {
        let router =
            router_with_path("/widgets").layer_documented(MockDocLayer { header_name: "X-A" });

        let op = op_for(&router, "/widgets");
        assert!(param_names(&op).iter().any(|n| n == "X-A"));
    }

    #[test]
    fn layer_documented_only_affects_routes_present_before_call() {
        let router_a = router_with_path("/a");
        let router_b = router_with_path("/b");

        let merged = router_a
            .layer_documented(MockDocLayer { header_name: "X-A" })
            .merge(router_b);

        let op_a = op_for(&merged, "/a");
        let op_b = op_for(&merged, "/b");

        assert!(
            param_names(&op_a).iter().any(|n| n == "X-A"),
            "/a should have the layer's header"
        );
        assert!(
            !param_names(&op_b).iter().any(|n| n == "X-A"),
            "/b was merged after the layer; must not carry its header"
        );
    }

    #[test]
    fn multiple_layer_documented_calls_accumulate_per_route() {
        let router = router_with_path("/widgets")
            .layer_documented(MockDocLayer { header_name: "X-A" })
            .layer_documented(MockDocLayer { header_name: "X-B" });

        let op = op_for(&router, "/widgets");
        let names = param_names(&op);
        assert!(names.iter().any(|n| n == "X-A"), "X-A from first layer");
        assert!(names.iter().any(|n| n == "X-B"), "X-B from second layer");
    }

    #[test]
    fn layer_documented_contribution_survives_merge_into_base() {
        let base = router_with_path("/health");
        let protected = router_with_path("/api/v1/models")
            .layer_documented(MockDocLayer { header_name: "X-A" });

        let merged = base.merge(protected);

        let health_op = op_for(&merged, "/health");
        let models_op = op_for(&merged, "/api/v1/models");

        assert!(
            !param_names(&health_op).iter().any(|n| n == "X-A"),
            "base route /health must not carry the layer's contribution",
        );
        assert!(
            param_names(&models_op).iter().any(|n| n == "X-A"),
            "merged-in route /api/v1/models must carry the layer's contribution",
        );
    }

    #[test]
    fn documented_layer_with_empty_contribution_is_pure_layer_application() {
        /// Layer that documents nothing.
        #[derive(Clone)]
        struct EmptyLayer;

        impl DocumentedLayer for EmptyLayer {
            fn contribution(&self) -> LayerContribution {
                LayerContribution::new()
            }
        }

        impl<Inner> Layer<Inner> for EmptyLayer {
            type Service = MockDocService<Inner>;

            fn layer(&self, inner: Inner) -> Self::Service {
                MockDocService { inner }
            }
        }

        let router = router_with_path("/widgets").layer_documented(EmptyLayer);
        let op = op_for(&router, "/widgets");
        assert!(op.parameters.is_none(), "no parameters injected");
    }

    #[test]
    fn tag_all_stamps_tag_on_current_operations() {
        let router = router_with_path("/widgets").tag_all("Widgets");

        let op = op_for(&router, "/widgets");
        assert_eq!(op_tags(&op), vec!["Widgets".to_string()]);
    }

    #[test]
    fn tag_all_does_not_affect_routes_merged_after() {
        let router_a = router_with_path("/a").tag_all("A");
        let router_b = router_with_path("/b");

        let merged = router_a.merge(router_b);

        let op_a = op_for(&merged, "/a");
        let op_b = op_for(&merged, "/b");

        assert_eq!(op_tags(&op_a), vec!["A".to_string()]);
        assert!(
            op_tags(&op_b).is_empty(),
            "/b was merged after tag_all; must not carry the tag"
        );
    }

    #[test]
    fn tag_all_deduplicates_when_called_twice() {
        let router = router_with_path("/widgets")
            .tag_all("Widgets")
            .tag_all("Widgets");

        let op = op_for(&router, "/widgets");
        assert_eq!(op_tags(&op), vec!["Widgets".to_string()]);
    }

    #[test]
    fn tag_all_merges_with_existing_tags() {
        // The operation already carries a handler-level tag before `tag_all`.
        let item = PathItem::new(
            HttpMethod::Get,
            OperationBuilder::new().tag("FromHandler").build(),
        );
        let paths = PathsBuilder::new().path("/widgets", item).build();
        let openapi = OpenApiBuilder::new().paths(paths).build();
        let router = OpenApiRouter::with_openapi(openapi).tag_all("FromRouter");

        let op = op_for(&router, "/widgets");
        let tags = op_tags(&op);
        assert!(tags.contains(&"FromHandler".to_string()));
        assert!(tags.contains(&"FromRouter".to_string()));
    }
}