Skip to main content

mockforge_registry_server/middleware/
mod.rs

1//! HTTP middleware
2
3pub mod api_token_auth;
4pub mod csrf;
5pub mod org_context;
6pub mod org_rate_limit;
7pub mod permission_check;
8// `permissions` module now lives in mockforge-registry-core. Re-exported
9// here so existing `crate::middleware::permissions::*` paths keep working.
10pub use mockforge_registry_core::permissions;
11pub mod rate_limit;
12pub mod request_id;
13// `scope_check` now lives in mockforge-registry-core. Re-exported so
14// existing `crate::middleware::scope_check::*` paths keep working.
15pub use mockforge_registry_core::scope_check;
16pub mod trusted_proxy;
17
18use axum::{
19    extract::{FromRequestParts, Request, State},
20    http::{request::Parts, HeaderMap, StatusCode},
21    middleware::Next,
22    response::{IntoResponse, Response},
23    Json,
24};
25use serde_json::json;
26use uuid::Uuid;
27
28use crate::auth::verify_token;
29use crate::middleware::api_token_auth::authenticate_api_token;
30use crate::AppState;
31
32pub use org_context::resolve_org_context;
33pub use rate_limit::rate_limit_middleware;
34pub use scope_check::{AuthType, ScopedAuth};
35
36/// JSON error response for authentication failures
37fn auth_error_response(message: &str, hint: &str) -> Response {
38    (
39        StatusCode::UNAUTHORIZED,
40        Json(json!({
41            "error": message,
42            "error_code": "AUTH_REQUIRED",
43            "status": 401,
44            "details": { "hint": hint }
45        })),
46    )
47        .into_response()
48}
49
50/// Extractor for authenticated user ID from JWT middleware
51#[derive(Debug, Clone)]
52pub struct AuthUser(pub Uuid);
53
54impl<S> FromRequestParts<S> for AuthUser
55where
56    S: Send + Sync,
57{
58    type Rejection = Response;
59
60    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
61        // Get user_id from request extensions (set by auth_middleware)
62        let user_id_str = parts.extensions.get::<String>().ok_or_else(|| {
63            auth_error_response(
64                "Authentication required",
65                "Include a valid Authorization header with your request",
66            )
67        })?;
68
69        // Parse UUID
70        let user_id = Uuid::parse_str(user_id_str).map_err(|_| {
71            (
72                StatusCode::INTERNAL_SERVER_ERROR,
73                Json(json!({
74                    "error": "Internal server error",
75                    "error_code": "INTERNAL_ERROR",
76                    "status": 500
77                })),
78            )
79                .into_response()
80        })?;
81
82        Ok(AuthUser(user_id))
83    }
84}
85
86/// Optional authenticated user extractor
87/// Returns None if no authentication is present, Some(user_id) if authenticated
88#[derive(Debug, Clone)]
89pub struct OptionalAuthUser(pub Option<Uuid>);
90
91impl<S> FromRequestParts<S> for OptionalAuthUser
92where
93    S: Send + Sync,
94{
95    type Rejection = StatusCode;
96
97    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
98        // Try to get user_id from request extensions (set by auth_middleware)
99        if let Some(user_id_str) = parts.extensions.get::<String>() {
100            if let Ok(user_id) = Uuid::parse_str(user_id_str) {
101                return Ok(OptionalAuthUser(Some(user_id)));
102            }
103        }
104        Ok(OptionalAuthUser(None))
105    }
106}
107
/// Percent-decode the value of a `?token=` query parameter.
///
/// Follows the conventions browsers use when encoding query strings: `%XX`
/// escapes (upper- or lower-case hex) are decoded and `+` becomes a space.
/// Every other byte is copied through unchanged.
///
/// Returns `None` for malformed input so that the caller treats it as "no
/// token" and the standard 401 path applies. Malformed input means:
/// - a `%` that is not followed by two hex digits, including a truncated
///   escape at the very end of the string (`"abc%"`, `"abc%4"`);
/// - a decoded byte sequence that is not valid UTF-8.
fn percent_decode_token(raw: &str) -> Option<String> {
    // Value of one ASCII hex digit; `to_digit(16)` is at most 15, so the
    // narrowing cast cannot truncate.
    let hex = |b: u8| (b as char).to_digit(16).map(|d| d as u8);

    let bytes = raw.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                // Checked lookups: a `%` too close to the end is malformed,
                // not a literal percent sign.
                let hi = hex(*bytes.get(i + 1)?)?;
                let lo = hex(*bytes.get(i + 2)?)?;
                out.push((hi << 4) | lo);
                i += 3;
            }
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8(out).ok()
}
136
137/// Extract and verify JWT token or API token.
138///
139/// Tokens are normally read from the `Authorization: Bearer <token>` header.
140/// Browsers cannot send custom headers on `EventSource` connections, so for
141/// SSE endpoints (and similar) we also accept the token in a `?token=` query
142/// string parameter. JWTs are short-lived, so the common "URL tokens leak to
143/// logs/referrer" risk is bounded — but we never recommend this path for
144/// long-lived API tokens. Header takes precedence when both are present.
145pub async fn auth_middleware(
146    State(state): State<AppState>,
147    headers: HeaderMap,
148    mut request: Request,
149    next: Next,
150) -> Result<Response, Response> {
151    let header_token = headers
152        .get("Authorization")
153        .and_then(|h| h.to_str().ok())
154        .and_then(|h| h.strip_prefix("Bearer "))
155        .map(|s| s.to_string());
156
157    // Pull `token=...` out of the query string without dragging in the `url`
158    // crate. SSE endpoints rely on this fallback because EventSource can't
159    // attach an Authorization header. Pure standard-library parse keeps the
160    // surface small.
161    let query_token = request.uri().query().and_then(|q| {
162        q.split('&').find_map(|pair| {
163            let (key, value) = pair.split_once('=')?;
164            if key == "token" {
165                percent_decode_token(value)
166            } else {
167                None
168            }
169        })
170    });
171
172    let owned_token = header_token.or(query_token).ok_or_else(|| {
173        auth_error_response(
174            "Authentication required",
175            "Include an Authorization: Bearer <token> header (or ?token= query for SSE).",
176        )
177    })?;
178
179    let token = owned_token.as_str();
180
181    // Check if this is an API token (starts with mfx_)
182    if token.starts_with("mfx_") {
183        match authenticate_api_token(&state, token).await.map_err(|_| {
184            auth_error_response(
185                "Authentication failed",
186                "API token validation error. Please try again.",
187            )
188        })? {
189            Some(auth_result) => {
190                request.extensions_mut().insert(auth_result.user_id.to_string());
191                request.extensions_mut().insert(AuthType::ApiToken);
192                request.extensions_mut().insert(auth_result.token);
193                return Ok(next.run(request).await);
194            }
195            None => {
196                return Err(auth_error_response(
197                    "Invalid API token",
198                    "The API token is invalid or has been revoked. Generate a new one at https://app.mockforge.dev/settings/tokens",
199                ));
200            }
201        }
202    }
203
204    // JWT authentication
205    let claims = verify_token(token, &state.config.jwt_secret).map_err(|_| {
206        auth_error_response(
207            "Invalid or expired token",
208            "Your session has expired. Please run 'mockforge cloud login' to re-authenticate.",
209        )
210    })?;
211
212    // Add user_id to request extensions
213    request.extensions_mut().insert(claims.sub.clone());
214    request.extensions_mut().insert(AuthType::Jwt);
215
216    Ok(next.run(request).await)
217}