silent_openapi/
middleware.rs1use crate::{OpenApiError, Result, SwaggerUiOptions};
6use async_trait::async_trait;
7use silent::{Handler, MiddleWareHandler, Next, Request, Response, StatusCode};
8use utoipa::openapi::OpenApi;
9
10#[derive(Clone)]
15pub struct SwaggerUiMiddleware {
16 ui_path: String,
18 api_doc_path: String,
20 openapi_json: String,
22 options: SwaggerUiOptions,
24}
25
26impl SwaggerUiMiddleware {
27 pub fn new(ui_path: &str, openapi: OpenApi) -> Result<Self> {
52 let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
53 let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
54
55 Ok(Self {
56 ui_path: ui_path.to_string(),
57 api_doc_path,
58 openapi_json,
59 options: SwaggerUiOptions::default(),
60 })
61 }
62
63 pub fn with_custom_api_doc_path(
65 ui_path: &str,
66 api_doc_path: &str,
67 openapi: OpenApi,
68 ) -> Result<Self> {
69 let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
70
71 Ok(Self {
72 ui_path: ui_path.to_string(),
73 api_doc_path: api_doc_path.to_string(),
74 openapi_json,
75 options: SwaggerUiOptions::default(),
76 })
77 }
78
79 pub fn with_options(
81 ui_path: &str,
82 openapi: OpenApi,
83 options: SwaggerUiOptions,
84 ) -> Result<Self> {
85 let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
86 let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
87
88 Ok(Self {
89 ui_path: ui_path.to_string(),
90 api_doc_path,
91 openapi_json,
92 options,
93 })
94 }
95
96 fn matches_swagger_path(&self, path: &str) -> bool {
98 path == self.ui_path
99 || path.starts_with(&format!("{}/", self.ui_path))
100 || path == self.api_doc_path
101 }
102
103 async fn handle_swagger_request(&self, path: &str) -> Result<Response> {
105 if path == self.api_doc_path {
106 self.handle_openapi_json().await
107 } else if path == self.ui_path {
108 self.handle_ui_redirect().await
109 } else {
110 self.handle_ui_resource(path).await
111 }
112 }
113
114 async fn handle_openapi_json(&self) -> Result<Response> {
116 let mut response = Response::empty();
117 response.set_status(StatusCode::OK);
118 response.set_header(
119 http::header::CONTENT_TYPE,
120 http::HeaderValue::from_static("application/json; charset=utf-8"),
121 );
122 response.set_header(
123 http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
124 http::HeaderValue::from_static("*"),
125 );
126 response.set_body(self.openapi_json.clone().into());
127 Ok(response)
128 }
129
130 async fn handle_ui_redirect(&self) -> Result<Response> {
132 let redirect_url = format!("{}/", self.ui_path);
133 let mut response = Response::empty();
134 response.set_status(StatusCode::MOVED_PERMANENTLY);
135 response.set_header(
136 http::header::LOCATION,
137 http::HeaderValue::from_str(&redirect_url)
138 .unwrap_or_else(|_| http::HeaderValue::from_static("/")),
139 );
140 Ok(response)
141 }
142
143 async fn handle_ui_resource(&self, path: &str) -> Result<Response> {
145 let relative_path = path
146 .strip_prefix(&format!("{}/", self.ui_path))
147 .unwrap_or("");
148
149 if relative_path.is_empty() || relative_path == "index.html" {
150 self.serve_swagger_ui_index().await
151 } else {
152 crate::ui_html::serve_asset(relative_path)
153 }
154 }
155
156 async fn serve_swagger_ui_index(&self) -> Result<Response> {
158 let html =
159 crate::ui_html::generate_index_html(&self.ui_path, &self.api_doc_path, &self.options);
160
161 let mut response = Response::empty();
162 response.set_status(StatusCode::OK);
163 response.set_header(
164 http::header::CONTENT_TYPE,
165 http::HeaderValue::from_static("text/html; charset=utf-8"),
166 );
167 response.set_header(
168 http::header::CACHE_CONTROL,
169 http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
170 );
171 response.set_body(html.into());
172 Ok(response)
173 }
174}
175
176#[async_trait]
177impl MiddleWareHandler for SwaggerUiMiddleware {
178 async fn handle(&self, req: Request, next: &Next) -> silent::Result<Response> {
180 let path = req.uri().path();
181 if self.matches_swagger_path(path) {
182 match self.handle_swagger_request(path).await {
183 Ok(response) => Ok(response),
184 Err(e) => {
185 eprintln!("Swagger UI中间件处理错误: {}", e);
186 let mut response = Response::empty();
188 response.set_status(StatusCode::INTERNAL_SERVER_ERROR);
189 response.set_body(format!("Swagger UI Error: {}", e).into());
190 Ok(response)
191 }
192 }
193 } else {
194 next.call(req).await
195 }
196 }
197}
198
199pub fn add_swagger_ui(
224 route: silent::prelude::Route,
225 ui_path: &str,
226 openapi: OpenApi,
227) -> silent::prelude::Route {
228 match SwaggerUiMiddleware::new(ui_path, openapi) {
229 Ok(middleware) => route.hook(middleware),
230 Err(e) => {
231 eprintln!("创建Swagger UI中间件失败: {}", e);
232 route
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use utoipa::OpenApi;

    // Minimal OpenAPI document with no paths or schemas, shared by every test.
    #[derive(OpenApi)]
    #[openapi(
        info(title = "Test API", version = "1.0.0"),
        paths(),
        components(schemas())
    )]
    struct TestApiDoc;

    /// Builds a middleware mounted at `/docs` with the default document path.
    fn docs_middleware() -> SwaggerUiMiddleware {
        SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap()
    }

    #[test]
    fn test_middleware_creation() {
        match SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()) {
            Ok(mw) => {
                assert_eq!(mw.ui_path, "/docs");
                assert_eq!(mw.api_doc_path, "/docs/openapi.json");
            }
            Err(err) => panic!("middleware construction failed: {}", err),
        }
    }

    #[test]
    fn test_path_matching() {
        let mw = docs_middleware();

        let handled = ["/docs", "/docs/", "/docs/index.html", "/docs/openapi.json"];
        for path in handled.iter() {
            assert!(mw.matches_swagger_path(path), "expected {} to match", path);
        }

        let ignored = ["/api/users", "/doc"];
        for path in ignored.iter() {
            assert!(!mw.matches_swagger_path(path), "expected {} not to match", path);
        }
    }

    #[tokio::test]
    async fn test_openapi_json_handling() {
        let mw = docs_middleware();
        let resp = mw.handle_openapi_json().await.unwrap();
        let headers = resp.headers();

        assert!(headers.get(http::header::CONTENT_TYPE).is_some());
        assert!(headers
            .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .is_some());
    }

    #[tokio::test]
    async fn test_redirect_on_base_path() {
        let mw = docs_middleware();
        let resp = mw.handle_swagger_request("/docs").await.unwrap();

        let location = resp.headers().get(http::header::LOCATION);
        assert!(location.is_some());
    }

    #[tokio::test]
    async fn test_custom_api_doc_path() {
        let custom_path = "/openapi-docs.json";
        let mw = SwaggerUiMiddleware::with_custom_api_doc_path(
            "/docs",
            custom_path,
            TestApiDoc::openapi(),
        )
        .unwrap();
        assert!(mw.matches_swagger_path(custom_path));

        let resp = mw.handle_swagger_request(custom_path).await.unwrap();
        let served_as_json = match resp.headers().get(http::header::CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap_or("").contains("application/json"),
            None => false,
        };
        assert!(served_as_json);
    }

    #[tokio::test]
    async fn test_non_match_request_path() {
        let mw = docs_middleware();
        let matched = mw.matches_swagger_path("/other");
        assert!(!matched);
    }

    #[tokio::test]
    async fn test_asset_404_branch() {
        let mw = docs_middleware();
        let resp = mw.handle_swagger_request("/docs/app.css").await.unwrap();

        let location = resp.headers().get(http::header::LOCATION);
        assert!(location.is_none());
    }

    #[tokio::test]
    async fn test_index_html_headers() {
        let mw = docs_middleware();
        let resp = mw.handle_swagger_request("/docs/index.html").await.unwrap();
        let headers = resp.headers();

        let content_type = headers.get(http::header::CONTENT_TYPE).unwrap();
        let is_html = content_type.to_str().unwrap_or("").contains("text/html");
        assert!(is_html);
        assert!(headers.get(http::header::CACHE_CONTROL).is_some());
    }
}
342
343