clawdb_server/http/
auth.rs1use std::{sync::Arc, time::Instant};
2
3use axum::{
4 extract::{MatchedPath, Request, State},
5 http::{header, HeaderValue, StatusCode},
6 middleware::Next,
7 response::{IntoResponse, Response},
8 Json,
9};
10use serde::Serialize;
11
12use crate::state::{AppState, RequestId};
13
14#[derive(Clone)]
15pub struct AuthContext {
16 pub token: String,
17 pub session: clawdb::ClawDBSession,
18}
19
20#[derive(Serialize)]
21struct ErrorBody {
22 error: String,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 detail: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 request_id: Option<String>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 component: Option<String>,
29}
30
31pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
32 let request_id = request
33 .headers()
34 .get("x-request-id")
35 .and_then(|value| value.to_str().ok())
36 .map(ToOwned::to_owned)
37 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
38 request
39 .extensions_mut()
40 .insert(RequestId(request_id.clone()));
41 let mut response = next.run(request).await;
42 if let Ok(value) = HeaderValue::from_str(&request_id) {
43 response.headers_mut().insert("x-request-id", value);
44 }
45 response
46}
47
48pub async fn metrics_middleware(
49 State(state): State<Arc<AppState>>,
50 request: Request,
51 next: Next,
52) -> Response {
53 let method = request.method().clone();
54 let path = request
55 .extensions()
56 .get::<MatchedPath>()
57 .map(|matched| matched.as_str().to_string())
58 .unwrap_or_else(|| request.uri().path().to_string());
59 let started = Instant::now();
60 let response = next.run(request).await;
61 state.metrics.observe_http(
62 method.as_str(),
63 &path,
64 response.status().as_u16(),
65 started.elapsed(),
66 );
67 if let Ok(count) = state.db.active_session_count().await {
68 state.metrics.set_active_sessions(count);
69 }
70 response
71}
72
73pub async fn auth_middleware(
74 State(state): State<Arc<AppState>>,
75 mut request: Request,
76 next: Next,
77) -> Response {
78 let token = match bearer_token(request.headers()) {
79 Some(token) => token,
80 None => return unauthorized(),
81 };
82
83 match state.db.validate_session(&token).await {
84 Ok(session) => {
85 request
86 .extensions_mut()
87 .insert(AuthContext { token, session });
88 next.run(request).await
89 }
90 Err(_) => unauthorized(),
91 }
92}
93
94pub async fn rate_limit_middleware(
95 State(state): State<Arc<AppState>>,
96 request: Request,
97 next: Next,
98) -> Response {
99 let Some(auth) = request.extensions().get::<AuthContext>().cloned() else {
100 return unauthorized();
101 };
102
103 let limiter = if request.method() == axum::http::Method::GET {
104 &state.http_read_limiter
105 } else {
106 &state.http_write_limiter
107 };
108
109 if let Err(not_until) = limiter.check_key(&auth.token) {
110 let retry_after = AppState::retry_after_seconds(¬_until);
111 let request_id = request
112 .extensions()
113 .get::<RequestId>()
114 .map(|value| value.0.clone());
115 let mut response = (
116 StatusCode::TOO_MANY_REQUESTS,
117 Json(ErrorBody {
118 error: "rate_limited".to_string(),
119 detail: None,
120 request_id,
121 component: None,
122 }),
123 )
124 .into_response();
125 if let Ok(value) = HeaderValue::from_str(&retry_after.to_string()) {
126 response.headers_mut().insert(header::RETRY_AFTER, value);
127 }
128 return response;
129 }
130
131 next.run(request).await
132}
133
134fn bearer_token(headers: &axum::http::HeaderMap) -> Option<String> {
135 headers
136 .get(header::AUTHORIZATION)
137 .and_then(|value| value.to_str().ok())
138 .and_then(|value| value.strip_prefix("Bearer "))
139 .map(ToOwned::to_owned)
140}
141
142fn unauthorized() -> Response {
143 (
144 StatusCode::UNAUTHORIZED,
145 Json(ErrorBody {
146 error: "unauthorized".to_string(),
147 detail: None,
148 request_id: None,
149 component: None,
150 }),
151 )
152 .into_response()
153}
154
155pub fn error_response(
156 status: StatusCode,
157 error: &str,
158 detail: Option<String>,
159 request_id: Option<String>,
160 component: Option<String>,
161) -> Response {
162 (
163 status,
164 Json(ErrorBody {
165 error: error.to_string(),
166 detail,
167 request_id,
168 component,
169 }),
170 )
171 .into_response()
172}