1use axum::{
2 Json, Router,
3 extract::Request,
4 middleware::{self, Next},
5 response::{IntoResponse, Response},
6 routing::get,
7};
8use serde_json::json;
9use tower::ServiceBuilder;
10use tower_http::{
11 compression::CompressionLayer,
12 cors::{Any, CorsLayer},
13 request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
14 trace::TraceLayer,
15};
16use tracing::Span;
17use uuid::Uuid;
18
/// Settings consumed by [`stack`] when assembling the middleware layers.
#[derive(Clone, Debug, Default)]
pub struct MiddlewareConfig {
    /// Comma-separated list of origins permitted by the CORS layer.
    ///
    /// The default (an empty string) allows any origin.
    pub allowed_origins: String,
}
30
31pub fn stack<S>(router: Router<S>, cfg: &MiddlewareConfig) -> Router<S>
35where
36 S: Clone + Send + Sync + 'static,
37{
38 router.route("/health", get(health)).layer(
39 ServiceBuilder::new()
40 .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
41 .layer(PropagateRequestIdLayer::x_request_id())
42 .layer(
43 TraceLayer::new_for_http()
44 .make_span_with(|req: &Request<_>| {
45 let request_id = req
46 .headers()
47 .get("x-request-id")
48 .and_then(|v| v.to_str().ok())
49 .unwrap_or("-");
50 tracing::info_span!(
51 "request",
52 method = %req.method(),
53 uri = %req.uri(),
54 request_id,
55 )
56 })
57 .on_response(
58 |resp: &Response<_>, latency: std::time::Duration, _span: &Span| {
59 tracing::info!(
60 status = resp.status().as_u16(),
61 latency_ms = latency.as_millis(),
62 "response"
63 );
64 },
65 ),
66 )
67 .layer(CompressionLayer::new())
68 .layer(
69 CorsLayer::new()
70 .allow_methods(Any)
71 .allow_headers(Any)
72 .allow_origin(cors_origin(&cfg.allowed_origins)),
73 )
74 .layer(middleware::from_fn(error_formatter)),
75 )
76}
77
78pub async fn serve(app: Router, port: u16) {
85 let addr = format!("0.0.0.0:{port}");
86 let listener = tokio::net::TcpListener::bind(&addr)
87 .await
88 .expect("bind failed");
89 tracing::info!("listening on {addr}");
90
91 axum::serve(listener, app)
92 .with_graceful_shutdown(shutdown_signal())
93 .await
94 .expect("server error");
95}
96
97async fn health() -> impl IntoResponse {
98 Json(json!({ "status": "ok" }))
99}
100
101async fn error_formatter(req: Request, next: Next) -> Response {
103 let resp = next.run(req).await;
104 if resp.status().is_server_error() {
105 let status = resp.status();
106 return (
107 status,
108 Json(json!({
109 "error": status.canonical_reason().unwrap_or("internal error"),
110 "request_id": Uuid::new_v4().to_string(),
111 })),
112 )
113 .into_response();
114 }
115 resp
116}
117
118fn cors_origin(allowed_origins: &str) -> tower_http::cors::AllowOrigin {
119 if allowed_origins.is_empty() {
120 return tower_http::cors::AllowOrigin::any();
121 }
122 let parsed: Vec<_> = allowed_origins
123 .split(',')
124 .filter_map(|o| o.trim().parse::<axum::http::HeaderValue>().ok())
125 .collect();
126 tower_http::cors::AllowOrigin::list(parsed)
127}
128
129async fn shutdown_signal() {
130 let ctrl_c = async {
131 tokio::signal::ctrl_c()
132 .await
133 .expect("ctrl-c handler failed");
134 };
135 #[cfg(unix)]
136 let terminate = async {
137 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
138 .expect("SIGTERM handler failed")
139 .recv()
140 .await;
141 };
142 #[cfg(not(unix))]
143 let terminate = std::future::pending::<()>();
144
145 tokio::select! {
146 _ = ctrl_c => tracing::info!("ctrl-c received, shutting down"),
147 _ = terminate => tracing::info!("SIGTERM received, shutting down"),
148 }
149}