Skip to main content

forge_runtime/gateway/
oauth.rs

1//! OAuth 2.1 Authorization Server for MCP.
2//!
3//! Implements Authorization Code + PKCE flow so MCP clients like Claude Code
4//! can auto-authenticate via browser login.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use axum::Json;
11use axum::extract::{Query, State};
12use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
13use axum::response::{Html, IntoResponse, Redirect, Response};
14use chrono::Utc;
15use forge_core::auth::Claims;
16use forge_core::oauth::{self, validate_redirect_uri};
17use serde::{Deserialize, Serialize};
18use tokio::sync::RwLock;
19use uuid::Uuid;
20
21use super::auth::AuthMiddleware;
22
/// HTML template for the authorization page, embedded at compile time.
const AUTHORIZE_PAGE: &str = include_str!("oauth_authorize.html");
/// Lifetime of an issued authorization code, in seconds.
const AUTH_CODE_TTL_SECS: i64 = 60;
/// Registration is refused once `forge_oauth_clients` holds this many rows.
const MAX_REGISTERED_CLIENTS: i64 = 1000;
/// The only PKCE `code_challenge_method` accepted by this server.
const CHALLENGE_METHOD_S256: &str = "S256";
/// Value written to the `aud` claim of tokens issued through this flow.
const MCP_AUDIENCE: &str = "forge:mcp";

// Rate limiting constants
const REGISTER_RATE_LIMIT: u32 = 10; // per minute per IP
const LOGIN_FAIL_RATE_LIMIT: u32 = 5; // per minute per IP
/// Length of one rate-limit window, in seconds.
const RATE_WINDOW_SECS: u64 = 60;
/// Map size above which stale rate-limit buckets and expired CSRF tokens are purged.
const RATE_CLEANUP_THRESHOLD: usize = 100;
34
35/// In-memory rate limiter for OAuth endpoints.
36#[derive(Clone, Default)]
37struct OAuthRateLimiter {
38    buckets: Arc<RwLock<HashMap<String, (u32, Instant)>>>,
39}
40
41impl OAuthRateLimiter {
42    async fn check(&self, key: &str, limit: u32) -> bool {
43        let mut buckets = self.buckets.write().await;
44        let now = Instant::now();
45        let window = Duration::from_secs(RATE_WINDOW_SECS);
46
47        // Purge stale entries periodically to prevent unbounded growth
48        if buckets.len() > RATE_CLEANUP_THRESHOLD {
49            buckets.retain(|_, (_, ts)| now.duration_since(*ts) <= window);
50        }
51
52        let entry = buckets.entry(key.to_string()).or_insert((0, now));
53        if now.duration_since(entry.1) > window {
54            *entry = (1, now);
55            return true;
56        }
57        if entry.0 >= limit {
58            return false;
59        }
60        entry.0 += 1;
61        true
62    }
63}
64
/// Shared state for OAuth endpoints.
#[derive(Clone)]
pub struct OAuthState {
    /// Postgres pool used for client, authorization-code and user queries.
    pool: sqlx::PgPool,
    /// Validates bearer tokens submitted on the consent path of the authorize form.
    auth_middleware: Arc<AuthMiddleware>,
    /// Signs the access tokens issued by the token endpoint.
    token_issuer: Arc<dyn forge_core::TokenIssuer>,
    /// Access-token lifetime in seconds; also reported as `expires_in`.
    access_token_ttl_secs: i64,
    /// Refresh-token lifetime in days.
    refresh_token_ttl_days: i64,
    /// When true, the authorize page offers email/password login.
    auth_is_hmac: bool,
    /// Application name rendered into the authorize page.
    project_name: String,
    /// Secret used to sign and verify the `forge_session` cookie.
    jwt_secret: String,
    /// Per-key counters for registration and login attempts.
    rate_limiter: OAuthRateLimiter,
    /// CSRF tokens: token -> expiry
    csrf_tokens: Arc<RwLock<HashMap<String, Instant>>>,
}
80
81impl OAuthState {
82    #[allow(clippy::too_many_arguments)]
83    pub fn new(
84        pool: sqlx::PgPool,
85        auth_middleware: Arc<AuthMiddleware>,
86        token_issuer: Arc<dyn forge_core::TokenIssuer>,
87        access_token_ttl_secs: i64,
88        refresh_token_ttl_days: i64,
89        auth_is_hmac: bool,
90        project_name: String,
91        jwt_secret: String,
92    ) -> Self {
93        Self {
94            pool,
95            auth_middleware,
96            token_issuer,
97            access_token_ttl_secs,
98            refresh_token_ttl_days,
99            auth_is_hmac,
100            project_name,
101            jwt_secret,
102            rate_limiter: OAuthRateLimiter::default(),
103            csrf_tokens: Arc::new(RwLock::new(HashMap::new())),
104        }
105    }
106
107    async fn store_csrf(&self, token: &str) {
108        let mut tokens = self.csrf_tokens.write().await;
109        let now = Instant::now();
110        let expiry = now + Duration::from_secs(600); // 10 min
111        tokens.insert(token.to_string(), expiry);
112        // Purge expired tokens periodically, not on every insert
113        if tokens.len() > RATE_CLEANUP_THRESHOLD {
114            tokens.retain(|_, exp| *exp > now);
115        }
116    }
117
118    async fn validate_csrf(&self, token: &str) -> bool {
119        let mut tokens = self.csrf_tokens.write().await;
120        if let Some(expiry) = tokens.remove(token) {
121            expiry > Instant::now()
122        } else {
123            false
124        }
125    }
126}
127
128// ── Well-known metadata endpoints ──────────────────────────────────────
129
/// JSON body returned by `well_known_oauth_metadata`.
///
/// Field names follow the OAuth authorization-server metadata document
/// format; every URL is derived from the request's base URL.
#[derive(Serialize)]
pub struct AuthorizationServerMetadata {
    /// Base URL of this server.
    issuer: String,
    /// URL of the authorize endpoint.
    authorization_endpoint: String,
    /// URL of the token endpoint.
    token_endpoint: String,
    /// URL of the dynamic client registration endpoint.
    registration_endpoint: String,
    /// Supported `response_type` values.
    response_types_supported: Vec<String>,
    /// Supported `grant_type` values.
    grant_types_supported: Vec<String>,
    /// Supported PKCE challenge methods.
    code_challenge_methods_supported: Vec<String>,
    /// Supported client authentication methods at the token endpoint.
    token_endpoint_auth_methods_supported: Vec<String>,
}
141
142pub async fn well_known_oauth_metadata(
143    headers: HeaderMap,
144    State(_state): State<Arc<OAuthState>>,
145) -> Json<AuthorizationServerMetadata> {
146    let base = base_url_from_headers(&headers);
147    Json(AuthorizationServerMetadata {
148        issuer: base.clone(),
149        authorization_endpoint: format!("{base}/_api/oauth/authorize"),
150        token_endpoint: format!("{base}/_api/oauth/token"),
151        registration_endpoint: format!("{base}/_api/oauth/register"),
152        response_types_supported: vec!["code".into()],
153        grant_types_supported: vec!["authorization_code".into(), "refresh_token".into()],
154        code_challenge_methods_supported: vec![CHALLENGE_METHOD_S256.into()],
155        token_endpoint_auth_methods_supported: vec!["none".into()],
156    })
157}
158
/// JSON body returned by `well_known_resource_metadata`.
#[derive(Serialize)]
pub struct ProtectedResourceMetadata {
    /// Base URL identifying the protected resource.
    resource: String,
    /// Authorization servers that can issue tokens for `resource`.
    authorization_servers: Vec<String>,
}
164
165pub async fn well_known_resource_metadata(
166    headers: HeaderMap,
167    State(_state): State<Arc<OAuthState>>,
168) -> Json<ProtectedResourceMetadata> {
169    let base = base_url_from_headers(&headers);
170    Json(ProtectedResourceMetadata {
171        resource: base.clone(),
172        authorization_servers: vec![base],
173    })
174}
175
176// ── Dynamic client registration ────────────────────────────────────────
177
/// JSON body accepted by `oauth_register`.
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Optional human-readable client name; shown on the authorize page when present.
    pub client_name: Option<String>,
    /// Redirect URIs the client may use; must not be empty.
    pub redirect_uris: Vec<String>,
    /// Requested grant types; an empty list is answered with `authorization_code`.
    #[serde(default)]
    pub grant_types: Vec<String>,
    /// Requested token-endpoint auth method; `"none"` is used when omitted.
    #[serde(default)]
    pub token_endpoint_auth_method: Option<String>,
}
187
/// JSON body returned by `oauth_register` on success.
#[derive(Serialize)]
pub struct RegisterResponse {
    /// Newly generated client identifier (a UUID string).
    pub client_id: String,
    /// Client name echoed from the request.
    pub client_name: Option<String>,
    /// Redirect URIs echoed from the request.
    pub redirect_uris: Vec<String>,
    /// Grant types from the request, or `authorization_code` when none were given.
    pub grant_types: Vec<String>,
    /// Token-endpoint auth method recorded for the client.
    pub token_endpoint_auth_method: String,
}
196
197pub async fn oauth_register(
198    headers: HeaderMap,
199    State(state): State<Arc<OAuthState>>,
200    Json(req): Json<RegisterRequest>,
201) -> Response {
202    let ip = client_ip(&headers);
203    let rate_key = format!("oauth_register:{ip}");
204    if !state
205        .rate_limiter
206        .check(&rate_key, REGISTER_RATE_LIMIT)
207        .await
208    {
209        return (
210            StatusCode::TOO_MANY_REQUESTS,
211            Json(serde_json::json!({
212                "error": "too_many_requests",
213                "error_description": "Rate limit exceeded for client registration"
214            })),
215        )
216            .into_response();
217    }
218
219    // Check client cap
220    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM forge_oauth_clients")
221        .fetch_one(&state.pool)
222        .await
223        .unwrap_or(0);
224    if count >= MAX_REGISTERED_CLIENTS {
225        return (
226            StatusCode::BAD_REQUEST,
227            Json(serde_json::json!({
228                "error": "too_many_clients",
229                "error_description": "Maximum number of registered clients reached"
230            })),
231        )
232            .into_response();
233    }
234
235    if req.redirect_uris.is_empty() {
236        return (
237            StatusCode::BAD_REQUEST,
238            Json(serde_json::json!({
239                "error": "invalid_client_metadata",
240                "error_description": "redirect_uris is required"
241            })),
242        )
243            .into_response();
244    }
245
246    let client_id = Uuid::new_v4().to_string();
247    let auth_method = req.token_endpoint_auth_method.as_deref().unwrap_or("none");
248
249    let result = sqlx::query(
250        "INSERT INTO forge_oauth_clients (client_id, client_name, redirect_uris, token_endpoint_auth_method) \
251         VALUES ($1, $2, $3, $4)"
252    )
253    .bind(&client_id)
254    .bind(&req.client_name)
255    .bind(&req.redirect_uris)
256    .bind(auth_method)
257    .execute(&state.pool)
258    .await;
259
260    if let Err(e) = result {
261        tracing::error!("Failed to register OAuth client: {e}");
262        return (
263            StatusCode::INTERNAL_SERVER_ERROR,
264            Json(serde_json::json!({
265                "error": "server_error",
266                "error_description": "Failed to register client"
267            })),
268        )
269            .into_response();
270    }
271
272    let grant_types = if req.grant_types.is_empty() {
273        vec!["authorization_code".into()]
274    } else {
275        req.grant_types
276    };
277
278    (
279        StatusCode::CREATED,
280        Json(RegisterResponse {
281            client_id,
282            client_name: req.client_name,
283            redirect_uris: req.redirect_uris,
284            grant_types,
285            token_endpoint_auth_method: auth_method.to_string(),
286        }),
287    )
288        .into_response()
289}
290
291// ── Authorization endpoint ─────────────────────────────────────────────
292
/// Query parameters accepted by `oauth_authorize_get`.
#[derive(Deserialize)]
pub struct AuthorizeQuery {
    /// Identifier of a client previously stored in `forge_oauth_clients`.
    pub client_id: String,
    /// Redirect URI; must match one registered for the client.
    pub redirect_uri: String,
    /// PKCE code challenge, carried through the form unchanged.
    pub code_challenge: String,
    /// PKCE challenge method; defaults to `S256` when omitted.
    #[serde(default = "default_s256")]
    pub code_challenge_method: String,
    /// Opaque client state, rendered into the form so it can be echoed back.
    pub state: Option<String>,
    /// Requested scopes as a single string.
    pub scope: Option<String>,
    /// Requested response type.
    pub response_type: Option<String>,
}
304
305fn default_s256() -> String {
306    CHALLENGE_METHOD_S256.into()
307}
308
309pub async fn oauth_authorize_get(
310    headers: HeaderMap,
311    Query(params): Query<AuthorizeQuery>,
312    State(state): State<Arc<OAuthState>>,
313) -> Response {
314    // Validate client_id
315    let client = sqlx::query_as::<_, (String, Option<String>, Vec<String>)>(
316        "SELECT client_id, client_name, redirect_uris FROM forge_oauth_clients WHERE client_id = $1"
317    )
318    .bind(&params.client_id)
319    .fetch_optional(&state.pool)
320    .await;
321
322    let (_, client_name, redirect_uris) = match client {
323        Ok(Some(c)) => c,
324        Ok(None) => {
325            return (
326                StatusCode::BAD_REQUEST,
327                Json(serde_json::json!({
328                    "error": "invalid_client",
329                    "error_description": "Unknown client_id"
330                })),
331            )
332                .into_response();
333        }
334        Err(e) => {
335            tracing::error!("OAuth client lookup failed: {e}");
336            return (
337                StatusCode::INTERNAL_SERVER_ERROR,
338                Json(serde_json::json!({
339                    "error": "server_error"
340                })),
341            )
342                .into_response();
343        }
344    };
345
346    // Validate redirect_uri (exact match, T2)
347    if !validate_redirect_uri(&params.redirect_uri, &redirect_uris) {
348        return (
349            StatusCode::BAD_REQUEST,
350            Json(serde_json::json!({
351                "error": "invalid_redirect_uri",
352                "error_description": "redirect_uri does not match any registered URI"
353            })),
354        )
355            .into_response();
356    }
357
358    if params.code_challenge_method != CHALLENGE_METHOD_S256 {
359        return (
360            StatusCode::BAD_REQUEST,
361            Json(serde_json::json!({
362                "error": "invalid_request",
363                "error_description": "Only S256 code_challenge_method is supported"
364            })),
365        )
366            .into_response();
367    }
368
369    // Check for existing session cookie (set by auth middleware on API calls)
370    let session_subject = extract_cookie(&headers, "forge_session")
371        .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
372    let has_session = session_subject.is_some();
373
374    // Generate CSRF token (T5)
375    let csrf_token = oauth::generate_random_token();
376    state.store_csrf(&csrf_token).await;
377
378    let auth_mode = if has_session {
379        "session" // user is known from cookie, show consent directly
380    } else if state.auth_is_hmac {
381        "hmac" // show email/password form
382    } else {
383        "external" // show "log in to your app first"
384    };
385    let display_name = client_name.as_deref().unwrap_or(&params.client_id);
386
387    let html = AUTHORIZE_PAGE
388        .replace("{{app_name}}", &html_escape(&state.project_name))
389        .replace("{{client_name}}", &html_escape(display_name))
390        .replace("{{csrf_token}}", &csrf_token)
391        .replace("{{client_id}}", &html_escape(&params.client_id))
392        .replace("{{redirect_uri}}", &html_escape(&params.redirect_uri))
393        .replace("{{code_challenge}}", &html_escape(&params.code_challenge))
394        .replace(
395            "{{code_challenge_method}}",
396            &html_escape(&params.code_challenge_method),
397        )
398        .replace(
399            "{{state}}",
400            &html_escape(params.state.as_deref().unwrap_or("")),
401        )
402        .replace(
403            "{{scope}}",
404            &html_escape(params.scope.as_deref().unwrap_or("")),
405        )
406        .replace("{{auth_mode}}", auth_mode)
407        .replace("{{authorize_url}}", "/_api/oauth/authorize")
408        .replace("{{error_message}}", "");
409
410    let mut response = (StatusCode::OK, Html(html)).into_response();
411    // T17: clickjacking protection
412    response
413        .headers_mut()
414        .insert("X-Frame-Options", HeaderValue::from_static("DENY"));
415    response.headers_mut().insert(
416        "Content-Security-Policy",
417        HeaderValue::from_static("frame-ancestors 'none'"),
418    );
419    // Set CSRF cookie (T1, T5)
420    let cookie = format!(
421        "forge_oauth_csrf={csrf_token}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=600"
422    );
423    if let Ok(cookie_val) = HeaderValue::from_str(&cookie) {
424        response
425            .headers_mut()
426            .insert(header::SET_COOKIE, cookie_val);
427    }
428    response
429}
430
/// Form body accepted by `oauth_authorize_post`.
///
/// The OAuth parameters mirror those of the GET request and are re-validated
/// on submit because the form could be tampered with.
#[derive(Deserialize)]
pub struct AuthorizeForm {
    /// CSRF token embedded in the page; must equal the `forge_oauth_csrf` cookie.
    pub csrf_token: String,
    /// Identifier of the client requesting authorization.
    pub client_id: String,
    /// Redirect URI the authorization code is delivered to.
    pub redirect_uri: String,
    /// PKCE code challenge stored alongside the authorization code.
    pub code_challenge: String,
    /// PKCE challenge method stored alongside the authorization code.
    pub code_challenge_method: String,
    /// Opaque client state echoed back on the redirect.
    pub state: Option<String>,
    /// Whitespace-separated scopes.
    pub scope: Option<String>,
    /// Requested response type.
    pub response_type: Option<String>,
    // Consent flow: existing token from localStorage
    pub token: Option<String>,
    // Login flow: email/password
    pub email: Option<String>,
    pub password: Option<String>,
}
447
448pub async fn oauth_authorize_post(
449    headers: HeaderMap,
450    State(state): State<Arc<OAuthState>>,
451    axum::Form(form): axum::Form<AuthorizeForm>,
452) -> Response {
453    // Validate CSRF (T5): check both cookie and form value
454    let csrf_from_cookie = extract_cookie(&headers, "forge_oauth_csrf");
455    let csrf_valid = if let Some(cookie_csrf) = csrf_from_cookie {
456        cookie_csrf == form.csrf_token && state.validate_csrf(&form.csrf_token).await
457    } else {
458        false
459    };
460    if !csrf_valid {
461        return (
462            StatusCode::FORBIDDEN,
463            Json(serde_json::json!({
464                "error": "csrf_validation_failed",
465                "error_description": "Invalid or expired CSRF token. Please try again."
466            })),
467        )
468            .into_response();
469    }
470
471    // Rate limit login failures (T7)
472    let ip = client_ip(&headers);
473    let rate_key = format!("oauth_login:{ip}");
474
475    // Validate client and redirect_uri again (form could be tampered)
476    let client = sqlx::query_as::<_, (Vec<String>,)>(
477        "SELECT redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
478    )
479    .bind(&form.client_id)
480    .fetch_optional(&state.pool)
481    .await;
482
483    let redirect_uris = match client {
484        Ok(Some((uris,))) => uris,
485        _ => {
486            return (
487                StatusCode::BAD_REQUEST,
488                Json(serde_json::json!({
489                    "error": "invalid_client"
490                })),
491            )
492                .into_response();
493        }
494    };
495
496    if !validate_redirect_uri(&form.redirect_uri, &redirect_uris) {
497        return (
498            StatusCode::BAD_REQUEST,
499            Json(serde_json::json!({
500                "error": "invalid_redirect_uri"
501            })),
502        )
503            .into_response();
504    }
505
506    // Authenticate user: try session cookie first, then token, then email/password
507    let user_id: Uuid;
508
509    let session_subject = extract_cookie(&headers, "forge_session")
510        .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
511
512    if let Some(subject) = session_subject {
513        // Session cookie flow: user identified by signed cookie from previous API calls.
514        // Subject may be a UUID (HMAC auth) or an external provider ID (Firebase, Clerk).
515        user_id = subject.parse::<Uuid>().unwrap_or_else(|_| {
516            // Non-UUID subject (Firebase UID, etc.): deterministic UUID from subject hash.
517            use sha2::Digest;
518            let hash: [u8; 32] = sha2::Sha256::digest(subject.as_bytes()).into();
519            let mut bytes = [0u8; 16];
520            bytes.copy_from_slice(&hash[..16]);
521            Uuid::from_bytes(bytes)
522        });
523    } else if let Some(token) = &form.token {
524        // Consent flow: validate existing JWT
525        match state.auth_middleware.validate_token_async(token).await {
526            Ok(claims) => {
527                user_id = claims
528                    .user_id()
529                    .ok_or(())
530                    .map_err(|_| ())
531                    .unwrap_or_default();
532                if user_id.is_nil() {
533                    return authorize_error_redirect(
534                        &form.redirect_uri,
535                        form.state.as_deref(),
536                        "access_denied",
537                        "Invalid user identity in token",
538                    );
539                }
540            }
541            Err(_) => {
542                return authorize_error_redirect(
543                    &form.redirect_uri,
544                    form.state.as_deref(),
545                    "access_denied",
546                    "Invalid or expired token. Please log in again.",
547                );
548            }
549        }
550    } else if let (Some(email), Some(password)) = (&form.email, &form.password) {
551        // Login flow (HMAC mode only)
552        if !state.auth_is_hmac {
553            return authorize_error_redirect(
554                &form.redirect_uri,
555                form.state.as_deref(),
556                "access_denied",
557                "Direct login not supported with external auth provider",
558            );
559        }
560
561        if !state
562            .rate_limiter
563            .check(&rate_key, LOGIN_FAIL_RATE_LIMIT)
564            .await
565        {
566            return authorize_error_redirect(
567                &form.redirect_uri,
568                form.state.as_deref(),
569                "access_denied",
570                "Too many login attempts. Please try again later.",
571            );
572        }
573
574        // Query users table by convention
575        let row = sqlx::query_as::<_, (Uuid, Option<String>, Option<String>)>(
576            "SELECT id, password_hash, role::TEXT FROM users WHERE email = $1",
577        )
578        .bind(email)
579        .fetch_optional(&state.pool)
580        .await;
581
582        match row {
583            Ok(Some((uid, Some(hash), _role))) => match bcrypt::verify(password, &hash) {
584                Ok(true) => {
585                    user_id = uid;
586                }
587                _ => {
588                    return authorize_error_redirect(
589                        &form.redirect_uri,
590                        form.state.as_deref(),
591                        "access_denied",
592                        "Invalid email or password",
593                    );
594                }
595            },
596            _ => {
597                return authorize_error_redirect(
598                    &form.redirect_uri,
599                    form.state.as_deref(),
600                    "access_denied",
601                    "Invalid email or password",
602                );
603            }
604        }
605    } else {
606        return (
607            StatusCode::BAD_REQUEST,
608            Json(serde_json::json!({
609                "error": "invalid_request",
610                "error_description": "Must provide either a token or email/password"
611            })),
612        )
613            .into_response();
614    }
615
616    // Generate authorization code
617    let code = oauth::generate_random_token();
618    let expires_at = Utc::now() + chrono::Duration::seconds(AUTH_CODE_TTL_SECS);
619    let scopes: Vec<String> = form
620        .scope
621        .as_deref()
622        .map(|s| s.split_whitespace().map(String::from).collect())
623        .unwrap_or_default();
624
625    let result = sqlx::query(
626        "INSERT INTO forge_oauth_codes \
627         (code, client_id, user_id, redirect_uri, code_challenge, code_challenge_method, scopes, expires_at) \
628         VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
629    )
630    .bind(&code)
631    .bind(&form.client_id)
632    .bind(user_id)
633    .bind(&form.redirect_uri)
634    .bind(&form.code_challenge)
635    .bind(&form.code_challenge_method)
636    .bind(&scopes)
637    .bind(expires_at)
638    .execute(&state.pool)
639    .await;
640
641    if let Err(e) = result {
642        tracing::error!("Failed to store authorization code: {e}");
643        return authorize_error_redirect(
644            &form.redirect_uri,
645            form.state.as_deref(),
646            "server_error",
647            "Failed to generate authorization code",
648        );
649    }
650
651    // Redirect with code (T18: Referrer-Policy)
652    let mut redirect_url = format!("{}?code={}", form.redirect_uri, urlencoding(&code));
653    if let Some(st) = &form.state {
654        redirect_url.push_str(&format!("&state={}", urlencoding(st)));
655    }
656
657    let mut response = Redirect::to(&redirect_url).into_response();
658    response
659        .headers_mut()
660        .insert("Referrer-Policy", HeaderValue::from_static("no-referrer"));
661
662    // Set session cookie so the next authorize visit shows consent directly
663    // instead of the login form. This is same-origin (backend serves both
664    // the authorize page and this POST), so the cookie sticks.
665    let cookie_value = super::auth::sign_session_cookie(&user_id.to_string(), &state.jwt_secret);
666    let secure_flag = if is_https(&headers) { "; Secure" } else { "" };
667    let session_cookie = format!(
668        "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=86400{secure_flag}"
669    );
670    if let Ok(val) = HeaderValue::from_str(&session_cookie) {
671        response.headers_mut().append(header::SET_COOKIE, val);
672    }
673
674    response
675}
676
677// ── Token endpoint ─────────────────────────────────────────────────────
678
/// Body of a token request, parsed from either JSON or form encoding.
///
/// Which optional fields are required depends on `grant_type`; the grant
/// handlers check for the ones they need.
#[derive(Deserialize)]
pub struct TokenRequest {
    /// `authorization_code` or `refresh_token`.
    pub grant_type: String,
    /// Authorization code (code exchange only).
    pub code: Option<String>,
    /// Redirect URI used in the authorization request (code exchange only).
    pub redirect_uri: Option<String>,
    /// PKCE verifier matching the stored challenge (code exchange only).
    pub code_verifier: Option<String>,
    /// Client identifier; required for code exchange, optional for refresh.
    pub client_id: Option<String>,
    /// Refresh token to rotate (refresh only).
    pub refresh_token: Option<String>,
}
688
/// JSON body returned by the token endpoint on success.
#[derive(Serialize)]
pub struct TokenResponse {
    /// Newly issued access token.
    pub access_token: String,
    /// Token type; the handlers always set this to `Bearer`.
    pub token_type: String,
    /// Access-token lifetime in seconds.
    pub expires_in: i64,
    /// Refresh token paired with the access token.
    pub refresh_token: String,
}
696
697/// Token endpoint accepts both `application/json` and
698/// `application/x-www-form-urlencoded` (OAuth 2.1 standard).
699pub async fn oauth_token(
700    State(state): State<Arc<OAuthState>>,
701    headers: HeaderMap,
702    body: axum::body::Bytes,
703) -> Response {
704    let content_type = headers
705        .get(header::CONTENT_TYPE)
706        .and_then(|v| v.to_str().ok())
707        .unwrap_or("");
708
709    let req: TokenRequest = if content_type.starts_with("application/json") {
710        match serde_json::from_slice(&body) {
711            Ok(r) => r,
712            Err(e) => return token_error("invalid_request", &format!("Invalid JSON: {e}")),
713        }
714    } else {
715        // Default to form-urlencoded (OAuth standard)
716        match serde_urlencoded::from_bytes(&body) {
717            Ok(r) => r,
718            Err(e) => return token_error("invalid_request", &format!("Invalid form data: {e}")),
719        }
720    };
721
722    match req.grant_type.as_str() {
723        "authorization_code" => handle_code_exchange(&state, &req).await,
724        "refresh_token" => handle_refresh(&state, &req).await,
725        _ => (
726            StatusCode::BAD_REQUEST,
727            Json(serde_json::json!({
728                "error": "unsupported_grant_type"
729            })),
730        )
731            .into_response(),
732    }
733}
734
735async fn handle_code_exchange(state: &OAuthState, req: &TokenRequest) -> Response {
736    let code = match &req.code {
737        Some(c) => c,
738        None => return token_error("invalid_request", "code is required"),
739    };
740    let code_verifier = match &req.code_verifier {
741        Some(v) => v,
742        None => return token_error("invalid_request", "code_verifier is required"),
743    };
744    let redirect_uri = match &req.redirect_uri {
745        Some(r) => r,
746        None => return token_error("invalid_request", "redirect_uri is required"),
747    };
748    let client_id = match &req.client_id {
749        Some(c) => c,
750        None => return token_error("invalid_request", "client_id is required"),
751    };
752
753    // Atomic exchange: mark code as used and fetch in one query (T8)
754    let row = sqlx::query_as::<_, (String, Uuid, String, String, String, chrono::DateTime<Utc>)>(
755        "UPDATE forge_oauth_codes SET used_at = now() \
756         WHERE code = $1 AND used_at IS NULL \
757         RETURNING client_id, user_id, redirect_uri, code_challenge, code_challenge_method, expires_at"
758    )
759    .bind(code)
760    .fetch_optional(&state.pool)
761    .await;
762
763    let (
764        stored_client_id,
765        user_id,
766        stored_redirect,
767        stored_challenge,
768        challenge_method,
769        expires_at,
770    ) = match row {
771        Ok(Some(r)) => r,
772        Ok(None) => {
773            return token_error(
774                "invalid_grant",
775                "Invalid or already used authorization code",
776            );
777        }
778        Err(e) => {
779            tracing::error!("Failed to exchange authorization code: {e}");
780            return token_error("server_error", "Failed to exchange code");
781        }
782    };
783
784    // Check expiry
785    if Utc::now() > expires_at {
786        return token_error("invalid_grant", "Authorization code has expired");
787    }
788
789    // Verify client_id matches (T14)
790    if *client_id != stored_client_id {
791        return token_error("invalid_grant", "client_id does not match");
792    }
793
794    // Verify redirect_uri matches
795    if *redirect_uri != stored_redirect {
796        return token_error("invalid_grant", "redirect_uri does not match");
797    }
798
799    if challenge_method != CHALLENGE_METHOD_S256 {
800        return token_error("invalid_request", "Unsupported code_challenge_method");
801    }
802    if !forge_core::oauth::pkce::verify_s256(code_verifier, &stored_challenge) {
803        return token_error("invalid_grant", "PKCE verification failed");
804    }
805
806    let access_ttl = state.access_token_ttl_secs;
807    let refresh_ttl = state.refresh_token_ttl_days;
808
809    let pair = forge_core::auth::tokens::issue_token_pair_with_client(
810        &state.pool,
811        user_id,
812        &["user"],
813        access_ttl,
814        refresh_ttl,
815        Some(client_id),
816        mcp_token_issuer(state.token_issuer.clone()),
817    )
818    .await;
819
820    match pair {
821        Ok(pair) => (
822            StatusCode::OK,
823            Json(TokenResponse {
824                access_token: pair.access_token,
825                token_type: "Bearer".into(),
826                expires_in: access_ttl,
827                refresh_token: pair.refresh_token,
828            }),
829        )
830            .into_response(),
831        Err(e) => {
832            tracing::error!("Failed to issue token pair: {e}");
833            token_error("server_error", "Failed to issue tokens")
834        }
835    }
836}
837
838async fn handle_refresh(state: &OAuthState, req: &TokenRequest) -> Response {
839    let refresh_token = match &req.refresh_token {
840        Some(t) => t,
841        None => return token_error("invalid_request", "refresh_token is required"),
842    };
843    let client_id = req.client_id.as_deref();
844
845    let access_ttl = state.access_token_ttl_secs;
846    let refresh_ttl = state.refresh_token_ttl_days;
847
848    let pair = forge_core::auth::tokens::rotate_refresh_token_with_client(
849        &state.pool,
850        refresh_token,
851        &["user"],
852        access_ttl,
853        refresh_ttl,
854        client_id,
855        mcp_token_issuer(state.token_issuer.clone()),
856    )
857    .await;
858
859    match pair {
860        Ok(pair) => (
861            StatusCode::OK,
862            Json(TokenResponse {
863                access_token: pair.access_token,
864                token_type: "Bearer".into(),
865                expires_in: access_ttl,
866                refresh_token: pair.refresh_token,
867            }),
868        )
869            .into_response(),
870        Err(_) => token_error("invalid_grant", "Invalid or expired refresh token"),
871    }
872}
873
874// ── Helpers ────────────────────────────────────────────────────────────
875
876/// Build a token-signing closure scoped to MCP audience.
877fn mcp_token_issuer(
878    issuer: Arc<dyn forge_core::TokenIssuer>,
879) -> impl FnOnce(Uuid, &[&str], i64) -> forge_core::Result<String> {
880    move |uid, roles, ttl| {
881        let claims = Claims::builder()
882            .subject(uid)
883            .roles(roles.iter().map(|s| s.to_string()).collect())
884            .claim("aud".to_string(), serde_json::json!(MCP_AUDIENCE))
885            .duration_secs(ttl)
886            .build()
887            .map_err(forge_core::ForgeError::Internal)?;
888        issuer.sign(&claims)
889    }
890}
891
892fn is_https(headers: &HeaderMap) -> bool {
893    headers
894        .get("x-forwarded-proto")
895        .and_then(|v| v.to_str().ok())
896        .map(|s| s == "https")
897        .unwrap_or(false)
898}
899
900fn token_error(error: &str, description: &str) -> Response {
901    (
902        StatusCode::BAD_REQUEST,
903        Json(serde_json::json!({
904            "error": error,
905            "error_description": description
906        })),
907    )
908        .into_response()
909}
910
911fn authorize_error_redirect(
912    redirect_uri: &str,
913    state: Option<&str>,
914    error: &str,
915    description: &str,
916) -> Response {
917    let mut url = format!(
918        "{}?error={}&error_description={}",
919        redirect_uri,
920        urlencoding(error),
921        urlencoding(description),
922    );
923    if let Some(st) = state {
924        url.push_str(&format!("&state={}", urlencoding(st)));
925    }
926    Redirect::to(&url).into_response()
927}
928
929fn base_url_from_headers(headers: &HeaderMap) -> String {
930    let host = headers
931        .get("host")
932        .and_then(|v| v.to_str().ok())
933        .unwrap_or("localhost:9081");
934
935    let scheme = headers
936        .get("x-forwarded-proto")
937        .and_then(|v| v.to_str().ok())
938        .unwrap_or("http");
939
940    format!("{scheme}://{host}")
941}
942
943fn client_ip(headers: &HeaderMap) -> String {
944    headers
945        .get("x-forwarded-for")
946        .and_then(|v| v.to_str().ok())
947        .and_then(|s| s.split(',').next())
948        .map(|s| s.trim().to_string())
949        .or_else(|| {
950            headers
951                .get("x-real-ip")
952                .and_then(|v| v.to_str().ok())
953                .map(String::from)
954        })
955        .unwrap_or_else(|| "unknown".to_string())
956}
957
958fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
959    headers
960        .get(header::COOKIE)
961        .and_then(|v| v.to_str().ok())
962        .and_then(|cookies| {
963            cookies.split(';').map(|c| c.trim()).find_map(|c| {
964                let (k, v) = c.split_once('=')?;
965                if k == name { Some(v.to_string()) } else { None }
966            })
967        })
968}
969
/// Escapes `&`, `<`, `>`, `"` and `'` so `s` can be embedded in HTML text or
/// a quoted attribute value.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
977
/// Percent-encodes `s` for use as an OAuth query parameter value.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged; every
/// other byte of the UTF-8 encoding becomes `%XX` with uppercase hex digits.
fn urlencoding(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for &byte in s.as_bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}