Skip to main content

silent_openapi/
middleware.rs

//! Swagger UI middleware.
//!
//! Provides Swagger UI support as a middleware, so it can be integrated into existing routes more flexibly.
4
5use crate::{OpenApiError, Result, SwaggerUiOptions};
6use async_trait::async_trait;
7use silent::{Handler, MiddleWareHandler, Next, Request, Response, StatusCode};
8use utoipa::openapi::OpenApi;
9
10/// Swagger UI 中间件
11///
12/// 实现了Silent的MiddleWareHandler trait,可以作为中间件添加到路由中。
13/// 当请求匹配Swagger UI相关路径时,直接返回响应;否则继续执行后续处理器。
14#[derive(Clone)]
15pub struct SwaggerUiMiddleware {
16    /// Swagger UI的基础路径
17    ui_path: String,
18    /// OpenAPI JSON的路径
19    api_doc_path: String,
20    /// OpenAPI 规范的JSON字符串
21    openapi_json: String,
22    /// UI 配置
23    options: SwaggerUiOptions,
24}
25
26impl SwaggerUiMiddleware {
27    /// 创建新的Swagger UI中间件
28    ///
29    /// # 参数
30    ///
31    /// - `ui_path`: Swagger UI的访问路径,如 "/swagger-ui"
32    /// - `openapi`: OpenAPI规范对象
33    ///
34    /// # 示例
35    ///
36    /// ```ignore
37    /// use silent::prelude::*;
38    /// use silent_openapi::SwaggerUiMiddleware;
39    /// use utoipa::OpenApi;
40    ///
41    /// #[derive(OpenApi)]
42    /// #[openapi(paths(), components(schemas()))]
43    /// struct ApiDoc;
44    ///
45    /// let middleware = SwaggerUiMiddleware::new("/swagger-ui", ApiDoc::openapi());
46    ///
47    /// let route = Route::new("")
48    ///     .hook(middleware)
49    ///     .get(your_handler);
50    /// ```
51    pub fn new(ui_path: &str, openapi: OpenApi) -> Result<Self> {
52        let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
53        let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
54
55        Ok(Self {
56            ui_path: ui_path.to_string(),
57            api_doc_path,
58            openapi_json,
59            options: SwaggerUiOptions::default(),
60        })
61    }
62
63    /// 使用自定义的API文档路径
64    pub fn with_custom_api_doc_path(
65        ui_path: &str,
66        api_doc_path: &str,
67        openapi: OpenApi,
68    ) -> Result<Self> {
69        let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
70
71        Ok(Self {
72            ui_path: ui_path.to_string(),
73            api_doc_path: api_doc_path.to_string(),
74            openapi_json,
75            options: SwaggerUiOptions::default(),
76        })
77    }
78
79    /// 使用自定义选项创建中间件
80    pub fn with_options(
81        ui_path: &str,
82        openapi: OpenApi,
83        options: SwaggerUiOptions,
84    ) -> Result<Self> {
85        let api_doc_path = format!("{}/openapi.json", ui_path.trim_end_matches('/'));
86        let openapi_json = serde_json::to_string_pretty(&openapi).map_err(OpenApiError::Json)?;
87
88        Ok(Self {
89            ui_path: ui_path.to_string(),
90            api_doc_path,
91            openapi_json,
92            options,
93        })
94    }
95
96    /// 检查请求路径是否匹配Swagger UI相关路径
97    fn matches_swagger_path(&self, path: &str) -> bool {
98        path == self.ui_path
99            || path.starts_with(&format!("{}/", self.ui_path))
100            || path == self.api_doc_path
101    }
102
103    /// 处理Swagger UI相关请求
104    async fn handle_swagger_request(&self, path: &str) -> Result<Response> {
105        if path == self.api_doc_path {
106            self.handle_openapi_json().await
107        } else if path == self.ui_path {
108            self.handle_ui_redirect().await
109        } else {
110            self.handle_ui_resource(path).await
111        }
112    }
113
114    /// 处理OpenAPI JSON请求
115    async fn handle_openapi_json(&self) -> Result<Response> {
116        let mut response = Response::empty();
117        response.set_status(StatusCode::OK);
118        response.set_header(
119            http::header::CONTENT_TYPE,
120            http::HeaderValue::from_static("application/json; charset=utf-8"),
121        );
122        response.set_header(
123            http::header::ACCESS_CONTROL_ALLOW_ORIGIN,
124            http::HeaderValue::from_static("*"),
125        );
126        response.set_body(self.openapi_json.clone().into());
127        Ok(response)
128    }
129
130    /// 处理UI主页重定向
131    async fn handle_ui_redirect(&self) -> Result<Response> {
132        let redirect_url = format!("{}/", self.ui_path);
133        let mut response = Response::empty();
134        response.set_status(StatusCode::MOVED_PERMANENTLY);
135        response.set_header(
136            http::header::LOCATION,
137            http::HeaderValue::from_str(&redirect_url)
138                .unwrap_or_else(|_| http::HeaderValue::from_static("/")),
139        );
140        Ok(response)
141    }
142
143    /// 处理UI资源请求
144    async fn handle_ui_resource(&self, path: &str) -> Result<Response> {
145        let relative_path = path
146            .strip_prefix(&format!("{}/", self.ui_path))
147            .unwrap_or("");
148
149        if relative_path.is_empty() || relative_path == "index.html" {
150            self.serve_swagger_ui_index().await
151        } else {
152            crate::ui_html::serve_asset(relative_path)
153        }
154    }
155
156    /// 生成Swagger UI主页HTML
157    async fn serve_swagger_ui_index(&self) -> Result<Response> {
158        let html =
159            crate::ui_html::generate_index_html(&self.ui_path, &self.api_doc_path, &self.options);
160
161        let mut response = Response::empty();
162        response.set_status(StatusCode::OK);
163        response.set_header(
164            http::header::CONTENT_TYPE,
165            http::HeaderValue::from_static("text/html; charset=utf-8"),
166        );
167        response.set_header(
168            http::header::CACHE_CONTROL,
169            http::HeaderValue::from_static("no-cache, no-store, must-revalidate"),
170        );
171        response.set_body(html.into());
172        Ok(response)
173    }
174}
175
176#[async_trait]
177impl MiddleWareHandler for SwaggerUiMiddleware {
178    /// 处理请求:命中 Swagger 相关路径则拦截返回,否则交由下一个处理器
179    async fn handle(&self, req: Request, next: &Next) -> silent::Result<Response> {
180        let path = req.uri().path();
181        if self.matches_swagger_path(path) {
182            match self.handle_swagger_request(path).await {
183                Ok(response) => Ok(response),
184                Err(e) => {
185                    eprintln!("Swagger UI中间件处理错误: {}", e);
186                    // 返回适当的错误响应
187                    let mut response = Response::empty();
188                    response.set_status(StatusCode::INTERNAL_SERVER_ERROR);
189                    response.set_body(format!("Swagger UI Error: {}", e).into());
190                    Ok(response)
191                }
192            }
193        } else {
194            next.call(req).await
195        }
196    }
197}
198
199/// 便捷函数:创建Swagger UI中间件并添加到路由
200///
201/// # 参数
202///
203/// - `route`: 要添加中间件的路由
204/// - `ui_path`: Swagger UI的访问路径
205/// - `openapi`: OpenAPI规范对象
206///
207/// # 示例
208///
209/// ```ignore
210/// use silent::prelude::*;
211/// use silent_openapi::add_swagger_ui;
212/// use utoipa::OpenApi;
213///
214/// #[derive(OpenApi)]
215/// #[openapi(paths(), components(schemas()))]
216/// struct ApiDoc;
217///
218/// let route = Route::new("api")
219///     .get(some_handler);
220///
221/// let route_with_swagger = add_swagger_ui(route, "/docs", ApiDoc::openapi());
222/// ```
223pub fn add_swagger_ui(
224    route: silent::prelude::Route,
225    ui_path: &str,
226    openapi: OpenApi,
227) -> silent::prelude::Route {
228    match SwaggerUiMiddleware::new(ui_path, openapi) {
229        Ok(middleware) => route.hook(middleware),
230        Err(e) => {
231            eprintln!("创建Swagger UI中间件失败: {}", e);
232            route
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use utoipa::OpenApi;

    #[derive(OpenApi)]
    #[openapi(
        info(title = "Test API", version = "1.0.0"),
        paths(),
        components(schemas())
    )]
    struct TestApiDoc;

    /// Middleware mounted at `/docs` with the default OpenAPI document path.
    fn docs_middleware() -> SwaggerUiMiddleware {
        SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi()).unwrap()
    }

    #[test]
    fn test_middleware_creation() {
        let created = SwaggerUiMiddleware::new("/docs", TestApiDoc::openapi());
        assert!(created.is_ok());

        let mw = created.unwrap();
        assert_eq!(mw.ui_path, "/docs");
        assert_eq!(mw.api_doc_path, "/docs/openapi.json");
    }

    #[test]
    fn test_path_matching() {
        let mw = docs_middleware();

        for hit in ["/docs", "/docs/", "/docs/index.html", "/docs/openapi.json"] {
            assert!(mw.matches_swagger_path(hit), "expected {hit} to match");
        }
        for miss in ["/api/users", "/doc"] {
            assert!(!mw.matches_swagger_path(miss), "expected {miss} not to match");
        }
    }

    #[tokio::test]
    async fn test_openapi_json_handling() {
        let resp = docs_middleware().handle_openapi_json().await.unwrap();
        let headers = resp.headers();

        // Silent's Response exposes no public status getter, so headers are checked instead.
        assert!(headers.get(http::header::CONTENT_TYPE).is_some());
        // The CORS header must be present on the spec response.
        assert!(headers.get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN).is_some());
    }

    #[tokio::test]
    async fn test_redirect_on_base_path() {
        let resp = docs_middleware()
            .handle_swagger_request("/docs")
            .await
            .unwrap();
        // The status code cannot be read, so a LOCATION header stands in as proof of a redirect.
        assert!(resp.headers().get(http::header::LOCATION).is_some());
    }

    #[tokio::test]
    async fn test_custom_api_doc_path() {
        let mw = SwaggerUiMiddleware::with_custom_api_doc_path(
            "/docs",
            "/openapi-docs.json",
            TestApiDoc::openapi(),
        )
        .unwrap();
        // The custom document path must be recognised.
        assert!(mw.matches_swagger_path("/openapi-docs.json"));

        let resp = mw
            .handle_swagger_request("/openapi-docs.json")
            .await
            .unwrap();
        let is_json = match resp.headers().get(http::header::CONTENT_TYPE) {
            Some(value) => value.to_str().unwrap_or("").contains("application/json"),
            None => false,
        };
        assert!(is_json);
    }

    #[tokio::test]
    async fn test_non_match_request_path() {
        assert!(!docs_middleware().matches_swagger_path("/other"));
    }

    #[tokio::test]
    async fn test_asset_404_branch() {
        let resp = docs_middleware()
            .handle_swagger_request("/docs/app.css")
            .await
            .unwrap();
        // An asset request must not be answered with a redirect.
        assert!(resp.headers().get(http::header::LOCATION).is_none());
    }

    #[tokio::test]
    async fn test_index_html_headers() {
        let resp = docs_middleware()
            .handle_swagger_request("/docs/index.html")
            .await
            .unwrap();
        let content_type = resp.headers().get(http::header::CONTENT_TYPE).unwrap();
        assert!(content_type.to_str().unwrap_or("").contains("text/html"));
        assert!(resp.headers().get(http::header::CACHE_CONTROL).is_some());
    }
}
342
// The options type (`SwaggerUiOptions`) is exported from the crate root.