1use axum::http::HeaderName;
2use axum::{
3 extract::Request,
4 http::{HeaderValue, StatusCode},
5 middleware::Next,
6 response::Response,
7};
8use std::time::Instant;
9use tower_http::cors::{Any, CorsLayer};
10
11use crate::config::Config;
12
13pub fn create_cors_layer(config: &Config) -> CorsLayer {
19 let allow_credentials = config.cors.allow_credentials;
20
21 let mut cors = CorsLayer::new().allow_methods(
25 config
26 .cors
27 .allowed_methods
28 .iter()
29 .map(|m| m.parse().unwrap())
30 .collect::<Vec<_>>(),
31 );
32
33 if config.cors.allowed_headers.contains(&"*".to_string()) {
35 if allow_credentials {
36 cors = cors.allow_headers([
38 HeaderName::from_static("content-type"),
39 HeaderName::from_static("authorization"),
40 HeaderName::from_static("x-requested-with"),
41 HeaderName::from_static("accept"),
42 HeaderName::from_static("origin"),
43 ]);
44 } else {
45 cors = cors.allow_headers(Any);
46 }
47 } else {
48 let headers: Result<Vec<_>, _> = config
50 .cors
51 .allowed_headers
52 .iter()
53 .map(|h| HeaderName::from_bytes(h.as_bytes()))
54 .collect();
55 if let Ok(headers) = headers {
56 cors = cors.allow_headers(headers);
57 }
58 }
59
60 if config.cors.allowed_origins.contains(&"*".to_string()) {
63 if allow_credentials {
64 tracing::warn!(
65 "CORS: allow_credentials=true 与 allowed_origins=* 不兼容,已自动禁用 allow_credentials"
66 );
67 }
68 cors = cors.allow_origin(Any).allow_credentials(false);
69 } else {
70 let origins: Result<Vec<_>, _> = config
71 .cors
72 .allowed_origins
73 .iter()
74 .map(|o| o.parse())
75 .collect();
76 if let Ok(origins) = origins {
77 cors = cors.allow_origin(origins.into_iter().collect::<Vec<_>>());
78 }
79 cors = cors.allow_credentials(allow_credentials);
80 }
81
82 cors
83}
84
/// Extracts the trace id from a W3C `traceparent` header value.
///
/// The header has the form `version-traceid-parentid-flags`. At least three
/// `-`-separated fields must be present. The second field is returned only
/// when it is a valid W3C trace id: exactly 32 lowercase hex characters and
/// not all zeros. Otherwise `None` is returned and the caller can generate a
/// fresh id.
///
/// Strict validation also matters here because the returned id is echoed
/// into log lines and response headers. Accepting arbitrary characters would
/// let a client inject log-format delimiters such as `]`, `|` or `"`.
fn parse_trace_id(traceparent: &str) -> Option<String> {
    let mut fields = traceparent.split('-');
    let _version = fields.next()?;
    let trace_id = fields.next()?;
    // The parent-id field must be present, even though it is not used here.
    fields.next()?;

    let is_lower_hex = |b: u8| matches!(b, b'0'..=b'9' | b'a'..=b'f');
    let valid = trace_id.len() == 32
        && trace_id.bytes().all(is_lower_hex)
        && trace_id.bytes().any(|b| b != b'0');

    if valid {
        Some(trace_id.to_string())
    } else {
        None
    }
}
99
100fn generate_trace_id() -> String {
102 let id = uuid::Uuid::new_v4();
103 id.as_simple().to_string() }
105
106fn generate_span_id() -> String {
108 let id = uuid::Uuid::new_v4();
109 id.as_simple().to_string()[..16].to_string()
110}
111
/// Returns `true` for request paths that are not logged: the root page, the
/// health-check endpoint and the favicon. Matching is exact.
fn should_skip_logging(path: &str) -> bool {
    const QUIET_PATHS: [&str; 3] = ["/", "/health", "/favicon.ico"];
    QUIET_PATHS.contains(&path)
}
116
/// Classifies a request duration (in milliseconds) into a slow-request label.
///
/// Returns `None` below 200 ms. Otherwise it returns the label of the highest
/// threshold the duration reaches: 200 ms, 500 ms, 1 s, 2 s, 3 s, 5 s or 10 s.
fn slow_request_level(duration_ms: u128) -> Option<&'static str> {
    // Ordered from slowest to fastest so the first hit is the most severe one.
    const LEVELS: [(u128, &str); 7] = [
        (10_000, "🔴 CRITICAL >10s"),
        (5_000, "🔴 VERY_SLOW >5s"),
        (3_000, "🟠 SLOW >3s"),
        (2_000, "🟠 SLOW >2s"),
        (1_000, "🟡 SLOW >1s"),
        (500, "🟡 SLOW >500ms"),
        (200, "🟢 SLOW >200ms"),
    ];

    LEVELS
        .iter()
        .find(|&&(threshold, _)| duration_ms >= threshold)
        .map(|&(_, label)| label)
}
131
132pub async fn request_logging_middleware(
145 request: Request,
146 next: Next,
147) -> Result<Response, StatusCode> {
148 let method = request.method().clone();
149 let uri = request.uri().path().to_string();
150 let version = format!("{:?}", request.version());
151
152 if should_skip_logging(&uri) {
154 return Ok(next.run(request).await);
155 }
156
157 let trace_id = request
159 .headers()
160 .get("traceparent")
161 .and_then(|v| v.to_str().ok())
162 .and_then(parse_trace_id)
163 .unwrap_or_else(generate_trace_id);
164
165 let trace_short = &trace_id[..8.min(trace_id.len())];
167
168 let span_id = generate_span_id();
170 let span_short = &span_id[..8.min(span_id.len())];
171
172 let user_agent = request
174 .headers()
175 .get("user-agent")
176 .and_then(|v| v.to_str().ok())
177 .map(|ua| {
178 if ua.len() > 50 {
180 format!("{}...", &ua[..50])
181 } else {
182 ua.to_string()
183 }
184 })
185 .unwrap_or_default();
186
187 let start = Instant::now();
188
189 tracing::info!(
191 "→ {} {} {} [trace={}|span={}]",
192 method,
193 uri,
194 version,
195 trace_short,
196 span_short,
197 );
198
199 let response = next.run(request).await;
201 let duration = start.elapsed();
202 let duration_ms = duration.as_millis();
203 let status = response.status();
204
205 let traceparent_value = format!("00-{}-{}-01", trace_id, span_id);
207
208 if status.is_server_error() {
210 tracing::error!(
211 "← {} {} {} {}ms [trace={}|span={}] ua=\"{}\"",
212 status.as_u16(),
213 method,
214 uri,
215 duration_ms,
216 trace_short,
217 span_short,
218 user_agent,
219 );
220 } else if status.is_client_error() {
221 tracing::warn!(
222 "← {} {} {} {}ms [trace={}|span={}]",
223 status.as_u16(),
224 method,
225 uri,
226 duration_ms,
227 trace_short,
228 span_short,
229 );
230 } else {
231 tracing::info!(
232 "← {} {} {} {}ms [trace={}|span={}]",
233 status.as_u16(),
234 method,
235 uri,
236 duration_ms,
237 trace_short,
238 span_short,
239 );
240 }
241
242 if let Some(level) = slow_request_level(duration_ms) {
244 tracing::warn!(
245 "{} {}ms {} {} [trace={}|span={}] ua=\"{}\"",
246 level,
247 duration_ms,
248 method,
249 uri,
250 trace_short,
251 span_short,
252 user_agent,
253 );
254 }
255
256 let mut response = response;
258 if let Ok(v) = HeaderValue::from_str(&traceparent_value) {
259 response.headers_mut().insert("traceparent", v);
260 }
261 if let Ok(v) = HeaderValue::from_str(trace_short) {
262 response
263 .headers_mut()
264 .insert("x-trace-id", v);
265 }
266
267 Ok(response)
268}
269
270pub fn grpc_log_request(req: &axum::http::Request<impl std::any::Any>) -> (String, String, Instant) {
277 let path = req.uri().path().to_string();
278
279 let trace_id = req
281 .headers()
282 .get("traceparent")
283 .and_then(|v| v.to_str().ok())
284 .and_then(parse_trace_id)
285 .unwrap_or_else(generate_trace_id);
286
287 let trace_short = trace_id[..8.min(trace_id.len())].to_string();
288 let span_id = generate_span_id();
289 let span_short = span_id[..8.min(span_id.len())].to_string();
290
291 tracing::info!(
292 "→ gRPC {} [trace={}|span={}]",
293 path,
294 trace_short,
295 span_short,
296 );
297
298 (trace_short, span_short, Instant::now())
299}
300
301pub fn grpc_log_response(
303 path: &str,
304 status: StatusCode,
305 trace_short: &str,
306 span_short: &str,
307 start: Instant,
308) {
309 let duration_ms = start.elapsed().as_millis();
310
311 if status.is_success() {
312 tracing::info!(
313 "← gRPC {} {} {}ms [trace={}|span={}]",
314 status.as_u16(),
315 path,
316 duration_ms,
317 trace_short,
318 span_short,
319 );
320 } else {
321 tracing::warn!(
322 "← gRPC {} {} {}ms [trace={}|span={}]",
323 status.as_u16(),
324 path,
325 duration_ms,
326 trace_short,
327 span_short,
328 );
329 }
330
331 if let Some(level) = slow_request_level(duration_ms) {
333 tracing::warn!(
334 "{} {}ms gRPC {} [trace={}|span={}]",
335 level,
336 duration_ms,
337 path,
338 trace_short,
339 span_short,
340 );
341 }
342}