fbc_starter/http/
middleware.rs1use axum::http::HeaderName;
2use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
3use tower_http::{
4 cors::{Any, CorsLayer},
5 trace::TraceLayer,
6};
7
8use crate::config::Config;
9
10pub fn create_cors_layer(config: &Config) -> CorsLayer {
12 let allow_credentials = config.cors.allow_credentials;
13
14 let mut cors = CorsLayer::new().allow_methods(
18 config
19 .cors
20 .allowed_methods
21 .iter()
22 .map(|m| m.parse().unwrap())
23 .collect::<Vec<_>>(),
24 );
25
26 if config.cors.allowed_headers.contains(&"*".to_string()) {
28 if allow_credentials {
29 cors = cors.allow_headers([
31 HeaderName::from_static("content-type"),
32 HeaderName::from_static("authorization"),
33 HeaderName::from_static("x-requested-with"),
34 HeaderName::from_static("accept"),
35 HeaderName::from_static("origin"),
36 ]);
37 } else {
38 cors = cors.allow_headers(Any);
39 }
40 } else {
41 let headers: Result<Vec<_>, _> = config
43 .cors
44 .allowed_headers
45 .iter()
46 .map(|h| HeaderName::from_bytes(h.as_bytes()))
47 .collect();
48 if let Ok(headers) = headers {
49 cors = cors.allow_headers(headers);
50 }
51 }
52
53 if config.cors.allowed_origins.contains(&"*".to_string()) {
56 if allow_credentials {
57 tracing::warn!(
58 "CORS: allow_credentials=true 与 allowed_origins=* 不兼容,已自动禁用 allow_credentials"
59 );
60 }
61 cors = cors.allow_origin(Any).allow_credentials(false);
62 } else {
63 let origins: Result<Vec<_>, _> = config
64 .cors
65 .allowed_origins
66 .iter()
67 .map(|o| o.parse())
68 .collect();
69 if let Ok(origins) = origins {
70 cors = cors.allow_origin(origins.into_iter().collect::<Vec<_>>());
71 }
72 cors = cors.allow_credentials(allow_credentials);
73 }
74
75 cors
76}
77
78pub fn create_trace_layer(
80) -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>>
81{
82 TraceLayer::new_for_http()
83}
84
85pub async fn request_logger_middleware(
87 request: Request,
88 next: Next,
89) -> Result<Response, StatusCode> {
90 let method = request.method().clone();
91 let uri = request.uri().clone();
92 let start = std::time::Instant::now();
93
94 tracing::info!("{} {}", method, uri);
95
96 let response = next.run(request).await;
97 let duration = start.elapsed();
98
99 tracing::info!(
100 "{} {} - {} - {:?}",
101 method,
102 uri,
103 response.status(),
104 duration
105 );
106
107 Ok(response)
108}