Skip to main content

fbc_starter/http/
middleware.rs

1use axum::http::HeaderName;
2use axum::{
3    extract::Request,
4    http::{HeaderValue, StatusCode},
5    middleware::Next,
6    response::Response,
7};
8use std::time::Instant;
9use tower_http::cors::{Any, CorsLayer};
10
11use crate::config::Config;
12
13// =====================================================================
14// CORS 中间件
15// =====================================================================
16
17/// 创建 CORS 中间件
18pub fn create_cors_layer(config: &Config) -> CorsLayer {
19    let allow_credentials = config.cors.allow_credentials;
20
21    // 当 allow_credentials 为 true 时,不能使用 * 作为 allowed_headers 或 allowed_origins
22    // 需要明确指定允许的值
23
24    let mut cors = CorsLayer::new().allow_methods(
25        config
26            .cors
27            .allowed_methods
28            .iter()
29            .map(|m| m.parse().unwrap())
30            .collect::<Vec<_>>(),
31    );
32
33    // 处理 allowed_headers
34    if config.cors.allowed_headers.contains(&"*".to_string()) {
35        if allow_credentials {
36            // 当允许凭证时,使用常见的请求头列表
37            cors = cors.allow_headers([
38                HeaderName::from_static("content-type"),
39                HeaderName::from_static("authorization"),
40                HeaderName::from_static("x-requested-with"),
41                HeaderName::from_static("accept"),
42                HeaderName::from_static("origin"),
43            ]);
44        } else {
45            cors = cors.allow_headers(Any);
46        }
47    } else {
48        // 解析指定的请求头
49        let headers: Result<Vec<_>, _> = config
50            .cors
51            .allowed_headers
52            .iter()
53            .map(|h| HeaderName::from_bytes(h.as_bytes()))
54            .collect();
55        if let Ok(headers) = headers {
56            cors = cors.allow_headers(headers);
57        }
58    }
59
60    // 处理 allowed_origins
61    // 注意:当 allowed_origins 为 * 时,allow_credentials 必须为 false(CORS 规范要求)
62    if config.cors.allowed_origins.contains(&"*".to_string()) {
63        if allow_credentials {
64            tracing::warn!(
65                "CORS: allow_credentials=true 与 allowed_origins=* 不兼容,已自动禁用 allow_credentials"
66            );
67        }
68        cors = cors.allow_origin(Any).allow_credentials(false);
69    } else {
70        let origins: Result<Vec<_>, _> = config
71            .cors
72            .allowed_origins
73            .iter()
74            .map(|o| o.parse())
75            .collect();
76        if let Ok(origins) = origins {
77            cors = cors.allow_origin(origins.into_iter().collect::<Vec<_>>());
78        }
79        cors = cors.allow_credentials(allow_credentials);
80    }
81
82    cors
83}
84
85// =====================================================================
86// Trace Context 工具函数(W3C traceparent 标准)
87// =====================================================================
88
/// Extract the trace-id from a W3C `traceparent` header value.
///
/// Expected format: `{version}-{trace_id}-{parent_id}-{flags}`, for example
/// `00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01`.
///
/// Returns `None`, so the caller starts a fresh trace, unless both of these hold:
/// - at least three `-`-separated fields are present;
/// - the trace-id is exactly 32 lowercase hex digits that are not all zero.
///
/// The W3C Trace Context spec treats any other trace-id as invalid. Validating the characters
/// also keeps arbitrary client-supplied text out of logs and out of the response headers that
/// echo this id.
fn parse_trace_id(traceparent: &str) -> Option<String> {
    let mut fields = traceparent.split('-');
    fields.next()?; // version
    let trace_id = fields.next()?;
    // A parent-id field must follow the trace-id.
    fields.next()?;

    let is_lower_hex = trace_id.len() == 32
        && trace_id
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'));
    let is_all_zero = trace_id.bytes().all(|b| b == b'0');

    (is_lower_hex && !is_all_zero).then(|| trace_id.to_string())
}
99
100/// 生成 trace_id(32位 hex = 128 bit,符合 W3C 标准)
101fn generate_trace_id() -> String {
102    let id = uuid::Uuid::new_v4();
103    id.as_simple().to_string() // 32 hex chars without hyphens
104}
105
106/// 生成 span_id(16位 hex = 64 bit,符合 W3C 标准)
107fn generate_span_id() -> String {
108    let id = uuid::Uuid::new_v4();
109    id.as_simple().to_string()[..16].to_string()
110}
111
/// Return `true` for request paths that are exempt from request logging.
///
/// The exempt paths are the root, the health check and the favicon. Matching is exact.
fn should_skip_logging(path: &str) -> bool {
    const SKIPPED_PATHS: [&str; 3] = ["/", "/health", "/favicon.ico"];
    SKIPPED_PATHS.contains(&path)
}
116
/// Classify a request duration (in milliseconds) into a slow-request tier.
///
/// Returns the label of the slowest tier whose threshold `duration_ms` reaches. Returns `None`
/// when the request took less than 200 ms, meaning no slow-request warning is needed.
/// Thresholds are inclusive lower bounds, so exactly 1000 ms yields `"🟡 SLOW >1s"`.
fn slow_request_level(duration_ms: u128) -> Option<&'static str> {
    // (inclusive lower bound in ms, label), ordered slowest tier first so the first match wins.
    const LEVELS: [(u128, &str); 7] = [
        (10_000, "🔴 CRITICAL >10s"),
        (5_000, "🔴 VERY_SLOW >5s"),
        (3_000, "🟠 SLOW >3s"),
        (2_000, "🟠 SLOW >2s"),
        (1_000, "🟡 SLOW >1s"),
        (500, "🟡 SLOW >500ms"),
        (200, "🟢 SLOW >200ms"),
    ];

    LEVELS
        .iter()
        .find(|&&(threshold, _)| duration_ms >= threshold)
        .map(|&(_, label)| label)
}
131
132// =====================================================================
133// 生产级 HTTP 请求日志中间件
134// =====================================================================
135
136/// 生产级 HTTP 请求日志中间件
137///
138/// 功能:
139/// - 提取/生成 W3C traceparent trace_id(分布式链路追踪)
140/// - 记录请求方法、路径、状态码、耗时
141/// - 多级慢请求告警(200ms/500ms/1s/2s/3s/5s/10s)
142/// - 自动跳过健康检查等路径
143/// - trace_id/span_id 传递到下游(通过 response header)
144pub async fn request_logging_middleware(
145    request: Request,
146    next: Next,
147) -> Result<Response, StatusCode> {
148    let method = request.method().clone();
149    let uri = request.uri().path().to_string();
150    let version = format!("{:?}", request.version());
151
152    // 跳过健康检查等路径
153    if should_skip_logging(&uri) {
154        return Ok(next.run(request).await);
155    }
156
157    // 提取或生成 trace_id(从 traceparent header)
158    let trace_id = request
159        .headers()
160        .get("traceparent")
161        .and_then(|v| v.to_str().ok())
162        .and_then(parse_trace_id)
163        .unwrap_or_else(generate_trace_id);
164
165    // 取 trace_id 前 8 位用于简短显示
166    let trace_short = &trace_id[..8.min(trace_id.len())];
167
168    // 生成本服务的 span_id
169    let span_id = generate_span_id();
170    let span_short = &span_id[..8.min(span_id.len())];
171
172    // 提取 user-agent(简化)
173    let user_agent = request
174        .headers()
175        .get("user-agent")
176        .and_then(|v| v.to_str().ok())
177        .map(|ua| {
178            // 截取前 50 个字符
179            if ua.len() > 50 {
180                format!("{}...", &ua[..50])
181            } else {
182                ua.to_string()
183            }
184        })
185        .unwrap_or_default();
186
187    let start = Instant::now();
188
189    // 记录请求开始
190    tracing::info!(
191        "→ {} {} {} [trace={}|span={}]",
192        method,
193        uri,
194        version,
195        trace_short,
196        span_short,
197    );
198
199    // 执行请求
200    let response = next.run(request).await;
201    let duration = start.elapsed();
202    let duration_ms = duration.as_millis();
203    let status = response.status();
204
205    // 构建 traceparent 用于传递(带本地 span_id)
206    let traceparent_value = format!("00-{}-{}-01", trace_id, span_id);
207
208    // 记录请求完成(含状态码和耗时)
209    if status.is_server_error() {
210        tracing::error!(
211            "← {} {} {} {}ms [trace={}|span={}] ua=\"{}\"",
212            status.as_u16(),
213            method,
214            uri,
215            duration_ms,
216            trace_short,
217            span_short,
218            user_agent,
219        );
220    } else if status.is_client_error() {
221        tracing::warn!(
222            "← {} {} {} {}ms [trace={}|span={}]",
223            status.as_u16(),
224            method,
225            uri,
226            duration_ms,
227            trace_short,
228            span_short,
229        );
230    } else {
231        tracing::info!(
232            "← {} {} {} {}ms [trace={}|span={}]",
233            status.as_u16(),
234            method,
235            uri,
236            duration_ms,
237            trace_short,
238            span_short,
239        );
240    }
241
242    // 多级慢请求告警
243    if let Some(level) = slow_request_level(duration_ms) {
244        tracing::warn!(
245            "{} {}ms {} {} [trace={}|span={}] ua=\"{}\"",
246            level,
247            duration_ms,
248            method,
249            uri,
250            trace_short,
251            span_short,
252            user_agent,
253        );
254    }
255
256    // 将 traceparent 注入响应头(方便客户端和前端调试)
257    let mut response = response;
258    if let Ok(v) = HeaderValue::from_str(&traceparent_value) {
259        response.headers_mut().insert("traceparent", v);
260    }
261    if let Ok(v) = HeaderValue::from_str(trace_short) {
262        response
263            .headers_mut()
264            .insert("x-trace-id", v);
265    }
266
267    Ok(response)
268}
269
270// =====================================================================
271// gRPC 请求日志(供 server.rs 中 fallback_service 调用)
272// =====================================================================
273
274/// 记录 gRPC 请求日志(请求开始)
275/// 返回 (trace_id_short, span_id_short, start_time)
276pub fn grpc_log_request(req: &axum::http::Request<impl std::any::Any>) -> (String, String, Instant) {
277    let path = req.uri().path().to_string();
278
279    // 从 gRPC metadata 中提取 traceparent(HTTP/2 header)
280    let trace_id = req
281        .headers()
282        .get("traceparent")
283        .and_then(|v| v.to_str().ok())
284        .and_then(parse_trace_id)
285        .unwrap_or_else(generate_trace_id);
286
287    let trace_short = trace_id[..8.min(trace_id.len())].to_string();
288    let span_id = generate_span_id();
289    let span_short = span_id[..8.min(span_id.len())].to_string();
290
291    tracing::info!(
292        "→ gRPC {} [trace={}|span={}]",
293        path,
294        trace_short,
295        span_short,
296    );
297
298    (trace_short, span_short, Instant::now())
299}
300
301/// 记录 gRPC 请求日志(请求完成)
302pub fn grpc_log_response(
303    path: &str,
304    status: StatusCode,
305    trace_short: &str,
306    span_short: &str,
307    start: Instant,
308) {
309    let duration_ms = start.elapsed().as_millis();
310
311    if status.is_success() {
312        tracing::info!(
313            "← gRPC {} {} {}ms [trace={}|span={}]",
314            status.as_u16(),
315            path,
316            duration_ms,
317            trace_short,
318            span_short,
319        );
320    } else {
321        tracing::warn!(
322            "← gRPC {} {} {}ms [trace={}|span={}]",
323            status.as_u16(),
324            path,
325            duration_ms,
326            trace_short,
327            span_short,
328        );
329    }
330
331    // 多级慢请求告警
332    if let Some(level) = slow_request_level(duration_ms) {
333        tracing::warn!(
334            "{} {}ms gRPC {} [trace={}|span={}]",
335            level,
336            duration_ms,
337            path,
338            trace_short,
339            span_short,
340        );
341    }
342}