clawdb_server/http/auth.rs

1use std::{sync::Arc, time::Instant};
2
3use axum::{
4    extract::{MatchedPath, Request, State},
5    http::{header, HeaderValue, StatusCode},
6    middleware::Next,
7    response::{IntoResponse, Response},
8    Json,
9};
10use serde::Serialize;
11
12use crate::state::{AppState, RequestId};
13
14#[derive(Clone)]
15pub struct AuthContext {
16    pub token: String,
17    pub session: clawdb::ClawDBSession,
18}
19
20#[derive(Serialize)]
21struct ErrorBody {
22    error: String,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    detail: Option<String>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    request_id: Option<String>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    component: Option<String>,
29}
30
31pub async fn request_id_middleware(mut request: Request, next: Next) -> Response {
32    let request_id = request
33        .headers()
34        .get("x-request-id")
35        .and_then(|value| value.to_str().ok())
36        .map(ToOwned::to_owned)
37        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
38    request
39        .extensions_mut()
40        .insert(RequestId(request_id.clone()));
41    let mut response = next.run(request).await;
42    if let Ok(value) = HeaderValue::from_str(&request_id) {
43        response.headers_mut().insert("x-request-id", value);
44    }
45    response
46}
47
48pub async fn metrics_middleware(
49    State(state): State<Arc<AppState>>,
50    request: Request,
51    next: Next,
52) -> Response {
53    let method = request.method().clone();
54    let path = request
55        .extensions()
56        .get::<MatchedPath>()
57        .map(|matched| matched.as_str().to_string())
58        .unwrap_or_else(|| request.uri().path().to_string());
59    let started = Instant::now();
60    let response = next.run(request).await;
61    state.metrics.observe_http(
62        method.as_str(),
63        &path,
64        response.status().as_u16(),
65        started.elapsed(),
66    );
67    if let Ok(count) = state.db.active_session_count().await {
68        state.metrics.set_active_sessions(count);
69    }
70    response
71}
72
73pub async fn auth_middleware(
74    State(state): State<Arc<AppState>>,
75    mut request: Request,
76    next: Next,
77) -> Response {
78    let token = match bearer_token(request.headers()) {
79        Some(token) => token,
80        None => return unauthorized(),
81    };
82
83    match state.db.validate_session(&token).await {
84        Ok(session) => {
85            request
86                .extensions_mut()
87                .insert(AuthContext { token, session });
88            next.run(request).await
89        }
90        Err(_) => unauthorized(),
91    }
92}
93
94pub async fn rate_limit_middleware(
95    State(state): State<Arc<AppState>>,
96    request: Request,
97    next: Next,
98) -> Response {
99    let Some(auth) = request.extensions().get::<AuthContext>().cloned() else {
100        return unauthorized();
101    };
102
103    let limiter = if request.method() == axum::http::Method::GET {
104        &state.http_read_limiter
105    } else {
106        &state.http_write_limiter
107    };
108
109    if let Err(not_until) = limiter.check_key(&auth.token) {
110        let retry_after = AppState::retry_after_seconds(&not_until);
111        let request_id = request
112            .extensions()
113            .get::<RequestId>()
114            .map(|value| value.0.clone());
115        let mut response = (
116            StatusCode::TOO_MANY_REQUESTS,
117            Json(ErrorBody {
118                error: "rate_limited".to_string(),
119                detail: None,
120                request_id,
121                component: None,
122            }),
123        )
124            .into_response();
125        if let Ok(value) = HeaderValue::from_str(&retry_after.to_string()) {
126            response.headers_mut().insert(header::RETRY_AFTER, value);
127        }
128        return response;
129    }
130
131    next.run(request).await
132}
133
134fn bearer_token(headers: &axum::http::HeaderMap) -> Option<String> {
135    headers
136        .get(header::AUTHORIZATION)
137        .and_then(|value| value.to_str().ok())
138        .and_then(|value| value.strip_prefix("Bearer "))
139        .map(ToOwned::to_owned)
140}
141
142fn unauthorized() -> Response {
143    (
144        StatusCode::UNAUTHORIZED,
145        Json(ErrorBody {
146            error: "unauthorized".to_string(),
147            detail: None,
148            request_id: None,
149            component: None,
150        }),
151    )
152        .into_response()
153}
154
155pub fn error_response(
156    status: StatusCode,
157    error: &str,
158    detail: Option<String>,
159    request_id: Option<String>,
160    component: Option<String>,
161) -> Response {
162    (
163        status,
164        Json(ErrorBody {
165            error: error.to_string(),
166            detail,
167            request_id,
168            component,
169        }),
170    )
171        .into_response()
172}