fbc_starter/http/
middleware.rs

1use axum::http::HeaderName;
2use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
3use tower_http::{
4    cors::{Any, CorsLayer},
5    trace::TraceLayer,
6};
7
8use crate::config::Config;
9
10/// 创建 CORS 中间件
11pub fn create_cors_layer(config: &Config) -> CorsLayer {
12    let allow_credentials = config.cors.allow_credentials;
13
14    // 当 allow_credentials 为 true 时,不能使用 * 作为 allowed_headers 或 allowed_origins
15    // 需要明确指定允许的值
16
17    let mut cors = CorsLayer::new().allow_methods(
18        config
19            .cors
20            .allowed_methods
21            .iter()
22            .map(|m| m.parse().unwrap())
23            .collect::<Vec<_>>(),
24    );
25
26    // 处理 allowed_headers
27    if config.cors.allowed_headers.contains(&"*".to_string()) {
28        if allow_credentials {
29            // 当允许凭证时,使用常见的请求头列表
30            cors = cors.allow_headers([
31                HeaderName::from_static("content-type"),
32                HeaderName::from_static("authorization"),
33                HeaderName::from_static("x-requested-with"),
34                HeaderName::from_static("accept"),
35                HeaderName::from_static("origin"),
36            ]);
37        } else {
38            cors = cors.allow_headers(Any);
39        }
40    } else {
41        // 解析指定的请求头
42        let headers: Result<Vec<_>, _> = config
43            .cors
44            .allowed_headers
45            .iter()
46            .map(|h| HeaderName::from_bytes(h.as_bytes()))
47            .collect();
48        if let Ok(headers) = headers {
49            cors = cors.allow_headers(headers);
50        }
51    }
52
53    // 处理 allowed_origins
54    // 注意:当 allowed_origins 为 * 时,allow_credentials 必须为 false(CORS 规范要求)
55    if config.cors.allowed_origins.contains(&"*".to_string()) {
56        if allow_credentials {
57            tracing::warn!(
58                "CORS: allow_credentials=true 与 allowed_origins=* 不兼容,已自动禁用 allow_credentials"
59            );
60        }
61        cors = cors.allow_origin(Any).allow_credentials(false);
62    } else {
63        let origins: Result<Vec<_>, _> = config
64            .cors
65            .allowed_origins
66            .iter()
67            .map(|o| o.parse())
68            .collect();
69        if let Ok(origins) = origins {
70            cors = cors.allow_origin(origins.into_iter().collect::<Vec<_>>());
71        }
72        cors = cors.allow_credentials(allow_credentials);
73    }
74
75    cors
76}
77
78/// 创建追踪中间件(使用默认配置)
79pub fn create_trace_layer(
80) -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>
81{
82    TraceLayer::new_for_http()
83}
84
85/// 请求日志中间件
86pub async fn request_logger_middleware(
87    request: Request,
88    next: Next,
89) -> Result<Response, StatusCode> {
90    let method = request.method().clone();
91    let uri = request.uri().clone();
92    let start = std::time::Instant::now();
93
94    tracing::info!("{} {}", method, uri);
95
96    let response = next.run(request).await;
97    let duration = start.elapsed();
98
99    tracing::info!(
100        "{} {} - {} - {:?}",
101        method,
102        uri,
103        response.status(),
104        duration
105    );
106
107    Ok(response)
108}