use axum::http::HeaderName;
use axum::{
extract::Request,
http::{HeaderValue, StatusCode},
middleware::Next,
response::Response,
};
use std::time::Instant;
use tower_http::cors::{Any, CorsLayer};
use crate::config::Config;
pub fn create_cors_layer(config: &Config) -> CorsLayer {
let allow_credentials = config.cors.allow_credentials;
let mut cors = CorsLayer::new().allow_methods(
config
.cors
.allowed_methods
.iter()
.map(|m| m.parse().unwrap())
.collect::<Vec<_>>(),
);
if config.cors.allowed_headers.contains(&"*".to_string()) {
if allow_credentials {
cors = cors.allow_headers([
HeaderName::from_static("content-type"),
HeaderName::from_static("authorization"),
HeaderName::from_static("x-requested-with"),
HeaderName::from_static("accept"),
HeaderName::from_static("origin"),
]);
} else {
cors = cors.allow_headers(Any);
}
} else {
let headers: Result<Vec<_>, _> = config
.cors
.allowed_headers
.iter()
.map(|h| HeaderName::from_bytes(h.as_bytes()))
.collect();
if let Ok(headers) = headers {
cors = cors.allow_headers(headers);
}
}
if config.cors.allowed_origins.contains(&"*".to_string()) {
if allow_credentials {
tracing::warn!(
"CORS: allow_credentials=true 与 allowed_origins=* 不兼容,已自动禁用 allow_credentials"
);
}
cors = cors.allow_origin(Any).allow_credentials(false);
} else {
let origins: Result<Vec<_>, _> = config
.cors
.allowed_origins
.iter()
.map(|o| o.parse())
.collect();
if let Ok(origins) = origins {
cors = cors.allow_origin(origins.into_iter().collect::<Vec<_>>());
}
cors = cors.allow_credentials(allow_credentials);
}
cors
}
/// Extracts the trace id from a W3C `traceparent` header value.
///
/// The expected layout is `version-trace_id-parent_id-flags` (at least four
/// dash-separated fields).
///
/// Returns the trace id, lowercased, only when it is exactly 32 hex digits and not
/// all zeros; the W3C spec defines the all-zero id as invalid. Otherwise returns
/// `None`, so the caller can fall back to generating a fresh id. Validating here
/// also guarantees the id is safe to echo back in response headers.
fn parse_trace_id(traceparent: &str) -> Option<String> {
    let mut fields = traceparent.trim().split('-');
    let _version = fields.next()?;
    let trace_id = fields.next()?;
    let _parent_id = fields.next()?;
    let _flags = fields.next()?;

    let is_valid = trace_id.len() == 32
        && trace_id.bytes().all(|b| b.is_ascii_hexdigit())
        && trace_id.bytes().any(|b| b != b'0');
    is_valid.then(|| trace_id.to_ascii_lowercase())
}
/// Creates a new trace id: the simple (hyphen-free) form of a random v4 UUID.
fn generate_trace_id() -> String {
    uuid::Uuid::new_v4().as_simple().to_string()
}
/// Creates a new span id: the first 16 characters of a random v4 UUID's simple form.
fn generate_span_id() -> String {
    let mut span = uuid::Uuid::new_v4().as_simple().to_string();
    span.truncate(16);
    span
}
/// Returns `true` for paths whose requests are not logged: `/`, `/health` and
/// `/favicon.ico` (exact matches only).
fn should_skip_logging(path: &str) -> bool {
    const QUIET_PATHS: [&str; 3] = ["/", "/health", "/favicon.ico"];
    QUIET_PATHS.iter().any(|quiet| *quiet == path)
}
/// Classifies a request duration (milliseconds) into a slow-request label.
///
/// Thresholds are inclusive lower bounds checked from the highest down. Durations
/// below 200 ms return `None`.
fn slow_request_level(duration_ms: u128) -> Option<&'static str> {
    const TIERS: [(u128, &str); 7] = [
        (10_000, "🔴 CRITICAL >10s"),
        (5_000, "🔴 VERY_SLOW >5s"),
        (3_000, "🟠 SLOW >3s"),
        (2_000, "🟠 SLOW >2s"),
        (1_000, "🟡 SLOW >1s"),
        (500, "🟡 SLOW >500ms"),
        (200, "🟢 SLOW >200ms"),
    ];
    TIERS
        .iter()
        .find(|(threshold, _)| duration_ms >= *threshold)
        .map(|(_, label)| *label)
}
pub async fn request_logging_middleware(
request: Request,
next: Next,
) -> Result<Response, StatusCode> {
let method = request.method().clone();
let uri = request.uri().path().to_string();
let version = format!("{:?}", request.version());
if should_skip_logging(&uri) {
return Ok(next.run(request).await);
}
let trace_id = request
.headers()
.get("traceparent")
.and_then(|v| v.to_str().ok())
.and_then(parse_trace_id)
.unwrap_or_else(generate_trace_id);
let trace_short = &trace_id[..8.min(trace_id.len())];
let span_id = generate_span_id();
let span_short = &span_id[..8.min(span_id.len())];
let user_agent = request
.headers()
.get("user-agent")
.and_then(|v| v.to_str().ok())
.map(|ua| {
if ua.len() > 50 {
format!("{}...", &ua[..50])
} else {
ua.to_string()
}
})
.unwrap_or_default();
let start = Instant::now();
tracing::info!(
"→ {} {} {} [trace={}|span={}]",
method,
uri,
version,
trace_short,
span_short,
);
let response = next.run(request).await;
let duration = start.elapsed();
let duration_ms = duration.as_millis();
let status = response.status();
let traceparent_value = format!("00-{}-{}-01", trace_id, span_id);
if status.is_server_error() {
tracing::error!(
"← {} {} {} {}ms [trace={}|span={}] ua=\"{}\"",
status.as_u16(),
method,
uri,
duration_ms,
trace_short,
span_short,
user_agent,
);
} else if status.is_client_error() {
tracing::warn!(
"← {} {} {} {}ms [trace={}|span={}]",
status.as_u16(),
method,
uri,
duration_ms,
trace_short,
span_short,
);
} else {
tracing::info!(
"← {} {} {} {}ms [trace={}|span={}]",
status.as_u16(),
method,
uri,
duration_ms,
trace_short,
span_short,
);
}
if let Some(level) = slow_request_level(duration_ms) {
tracing::warn!(
"{} {}ms {} {} [trace={}|span={}] ua=\"{}\"",
level,
duration_ms,
method,
uri,
trace_short,
span_short,
user_agent,
);
}
let mut response = response;
if let Ok(v) = HeaderValue::from_str(&traceparent_value) {
response.headers_mut().insert("traceparent", v);
}
if let Ok(v) = HeaderValue::from_str(trace_short) {
response
.headers_mut()
.insert("x-trace-id", v);
}
Ok(response)
}
/// Logs the start of a gRPC call and returns what `grpc_log_response` needs.
///
/// The trace id comes from the request's `traceparent` header when `parse_trace_id`
/// accepts it; otherwise a new one is generated. A new span id is always generated.
///
/// Returns `(trace_short, span_short, start)`: the first 8 characters of the trace
/// and span ids, and the `Instant` at which timing started.
///
/// The body type `B` is unconstrained. Only the URI and headers are read, so no
/// `Any`/`'static` bound is required.
pub fn grpc_log_request<B>(req: &axum::http::Request<B>) -> (String, String, Instant) {
    // Start timing at entry, matching the HTTP middleware, which captures its
    // `Instant` before emitting the request log line.
    let start = Instant::now();
    let path = req.uri().path();
    let trace_id = req
        .headers()
        .get("traceparent")
        .and_then(|v| v.to_str().ok())
        .and_then(parse_trace_id)
        .unwrap_or_else(generate_trace_id);
    let trace_short = trace_id[..8.min(trace_id.len())].to_string();
    let span_id = generate_span_id();
    let span_short = span_id[..8.min(span_id.len())].to_string();
    tracing::info!(
        "→ gRPC {} [trace={}|span={}]",
        path,
        trace_short,
        span_short,
    );
    (trace_short, span_short, start)
}
pub fn grpc_log_response(
path: &str,
status: StatusCode,
trace_short: &str,
span_short: &str,
start: Instant,
) {
let duration_ms = start.elapsed().as_millis();
if status.is_success() {
tracing::info!(
"← gRPC {} {} {}ms [trace={}|span={}]",
status.as_u16(),
path,
duration_ms,
trace_short,
span_short,
);
} else {
tracing::warn!(
"← gRPC {} {} {}ms [trace={}|span={}]",
status.as_u16(),
path,
duration_ms,
trace_short,
span_short,
);
}
if let Some(level) = slow_request_level(duration_ms) {
tracing::warn!(
"{} {}ms gRPC {} [trace={}|span={}]",
level,
duration_ms,
path,
trace_short,
span_short,
);
}
}