Skip to main content

mockforge_registry_server/middleware/
mod.rs

1//! HTTP middleware
2
3pub mod api_token_auth;
4pub mod csrf;
5pub mod org_context;
6pub mod org_rate_limit;
7pub mod past_due_writes;
8pub mod permission_check;
9// `permissions` module now lives in mockforge-registry-core. Re-exported
10// here so existing `crate::middleware::permissions::*` paths keep working.
11pub use mockforge_registry_core::permissions;
12pub mod rate_limit;
13pub mod request_id;
14// `scope_check` now lives in mockforge-registry-core. Re-exported so
15// existing `crate::middleware::scope_check::*` paths keep working.
16pub use mockforge_registry_core::scope_check;
17pub mod trusted_proxy;
18
19use axum::{
20    extract::{FromRequestParts, Request, State},
21    http::{request::Parts, HeaderMap, StatusCode},
22    middleware::Next,
23    response::{IntoResponse, Response},
24    Json,
25};
26use serde_json::json;
27use uuid::Uuid;
28
29use crate::auth::verify_token;
30use crate::middleware::api_token_auth::authenticate_api_token;
31use crate::AppState;
32
33pub use org_context::resolve_org_context;
34pub use rate_limit::rate_limit_middleware;
35pub use scope_check::{AuthType, ScopedAuth};
36
37/// JSON error response for authentication failures
38fn auth_error_response(message: &str, hint: &str) -> Response {
39    (
40        StatusCode::UNAUTHORIZED,
41        Json(json!({
42            "error": message,
43            "error_code": "AUTH_REQUIRED",
44            "status": 401,
45            "details": { "hint": hint }
46        })),
47    )
48        .into_response()
49}
50
51/// Extractor for authenticated user ID from JWT middleware
52#[derive(Debug, Clone)]
53pub struct AuthUser(pub Uuid);
54
55impl<S> FromRequestParts<S> for AuthUser
56where
57    S: Send + Sync,
58{
59    type Rejection = Response;
60
61    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
62        // Get user_id from request extensions (set by auth_middleware)
63        let user_id_str = parts.extensions.get::<String>().ok_or_else(|| {
64            auth_error_response(
65                "Authentication required",
66                "Include a valid Authorization header with your request",
67            )
68        })?;
69
70        // Parse UUID
71        let user_id = Uuid::parse_str(user_id_str).map_err(|_| {
72            (
73                StatusCode::INTERNAL_SERVER_ERROR,
74                Json(json!({
75                    "error": "Internal server error",
76                    "error_code": "INTERNAL_ERROR",
77                    "status": 500
78                })),
79            )
80                .into_response()
81        })?;
82
83        Ok(AuthUser(user_id))
84    }
85}
86
87/// Optional authenticated user extractor
88/// Returns None if no authentication is present, Some(user_id) if authenticated
89#[derive(Debug, Clone)]
90pub struct OptionalAuthUser(pub Option<Uuid>);
91
92impl<S> FromRequestParts<S> for OptionalAuthUser
93where
94    S: Send + Sync,
95{
96    type Rejection = StatusCode;
97
98    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
99        // Try to get user_id from request extensions (set by auth_middleware)
100        if let Some(user_id_str) = parts.extensions.get::<String>() {
101            if let Ok(user_id) = Uuid::parse_str(user_id_str) {
102                return Ok(OptionalAuthUser(Some(user_id)));
103            }
104        }
105        Ok(OptionalAuthUser(None))
106    }
107}
108
/// Percent-decode the value of a `?token=` query parameter.
///
/// `%XX` escapes (two hex digits, either case) are decoded, and `+` is mapped
/// to a space as in form-urlencoding. All other bytes pass through unchanged.
///
/// Returns `None` when the input is malformed. That covers a `%` not followed
/// by two hex digits (including a truncated escape at the end of the input)
/// and decoded bytes that are not valid UTF-8. The caller then treats the
/// token as absent, and the standard 401 path applies.
fn percent_decode_token(raw: &str) -> Option<String> {
    let mut out = Vec::with_capacity(raw.len());
    let mut bytes = raw.bytes();
    while let Some(byte) = bytes.next() {
        match byte {
            b'%' => {
                // `?` on `next()` rejects a truncated escape; `?` on
                // `to_digit` rejects non-hex characters.
                let hi = (bytes.next()? as char).to_digit(16)?;
                let lo = (bytes.next()? as char).to_digit(16)?;
                // Both digits are <= 15, so the combined value fits in a u8.
                out.push(((hi << 4) | lo) as u8);
            }
            b'+' => out.push(b' '),
            other => out.push(other),
        }
    }
    String::from_utf8(out).ok()
}
137
138/// Extract and verify JWT token or API token.
139///
140/// Tokens are normally read from the `Authorization: Bearer <token>` header.
141/// Browsers cannot send custom headers on `EventSource` connections, so for
142/// SSE endpoints (and similar) we also accept the token in a `?token=` query
143/// string parameter. JWTs are short-lived, so the common "URL tokens leak to
144/// logs/referrer" risk is bounded — but we never recommend this path for
145/// long-lived API tokens. Header takes precedence when both are present.
146pub async fn auth_middleware(
147    State(state): State<AppState>,
148    headers: HeaderMap,
149    mut request: Request,
150    next: Next,
151) -> Result<Response, Response> {
152    let header_token = headers
153        .get("Authorization")
154        .and_then(|h| h.to_str().ok())
155        .and_then(|h| h.strip_prefix("Bearer "))
156        .map(|s| s.to_string());
157
158    // Pull `token=...` out of the query string without dragging in the `url`
159    // crate. SSE endpoints rely on this fallback because EventSource can't
160    // attach an Authorization header. Pure standard-library parse keeps the
161    // surface small.
162    let query_token = request.uri().query().and_then(|q| {
163        q.split('&').find_map(|pair| {
164            let (key, value) = pair.split_once('=')?;
165            if key == "token" {
166                percent_decode_token(value)
167            } else {
168                None
169            }
170        })
171    });
172
173    let owned_token = header_token.or(query_token).ok_or_else(|| {
174        auth_error_response(
175            "Authentication required",
176            "Include an Authorization: Bearer <token> header (or ?token= query for SSE).",
177        )
178    })?;
179
180    let token = owned_token.as_str();
181
182    // Check if this is an API token (starts with mfx_)
183    if token.starts_with("mfx_") {
184        match authenticate_api_token(&state, token).await.map_err(|_| {
185            auth_error_response(
186                "Authentication failed",
187                "API token validation error. Please try again.",
188            )
189        })? {
190            Some(auth_result) => {
191                request.extensions_mut().insert(auth_result.user_id.to_string());
192                request.extensions_mut().insert(AuthType::ApiToken);
193                request.extensions_mut().insert(auth_result.token);
194                return Ok(next.run(request).await);
195            }
196            None => {
197                return Err(auth_error_response(
198                    "Invalid API token",
199                    "The API token is invalid or has been revoked. Generate a new one at https://app.mockforge.dev/settings/tokens",
200                ));
201            }
202        }
203    }
204
205    // JWT authentication
206    let claims = verify_token(token, &state.config.jwt_secret).map_err(|_| {
207        auth_error_response(
208            "Invalid or expired token",
209            "Your session has expired. Please run 'mockforge cloud login' to re-authenticate.",
210        )
211    })?;
212
213    // Add user_id to request extensions
214    request.extensions_mut().insert(claims.sub.clone());
215    request.extensions_mut().insert(AuthType::Jwt);
216
217    Ok(next.run(request).await)
218}