Skip to main content

forge_runtime/gateway/
oauth.rs

1//! OAuth 2.1 Authorization Server for MCP.
2//!
3//! Implements Authorization Code + PKCE flow so MCP clients like Claude Code
4//! can auto-authenticate via browser login.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use axum::Json;
11use axum::extract::{Query, State};
12use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
13use axum::response::{Html, IntoResponse, Redirect, Response};
14use chrono::Utc;
15use forge_core::auth::Claims;
16use forge_core::oauth::{self, validate_redirect_uri};
17use serde::{Deserialize, Serialize};
18use tokio::sync::RwLock;
19use uuid::Uuid;
20
21use super::auth::AuthMiddleware;
22
23const AUTHORIZE_PAGE: &str = include_str!("oauth_authorize.html");
24const AUTH_CODE_TTL_SECS: i64 = 60;
25const MAX_REGISTERED_CLIENTS: i64 = 1000;
26const CHALLENGE_METHOD_S256: &str = "S256";
27const MCP_AUDIENCE: &str = "forge:mcp";
28
29// Rate limiting constants
30const REGISTER_RATE_LIMIT: u32 = 10; // per minute per IP
31const LOGIN_FAIL_RATE_LIMIT: u32 = 5; // per minute per IP
32const RATE_WINDOW_SECS: u64 = 60;
33const RATE_CLEANUP_THRESHOLD: usize = 100;
34
35/// In-memory rate limiter for OAuth endpoints.
36#[derive(Clone, Default)]
37struct OAuthRateLimiter {
38    buckets: Arc<RwLock<HashMap<String, (u32, Instant)>>>,
39}
40
41impl OAuthRateLimiter {
42    async fn check(&self, key: &str, limit: u32) -> bool {
43        let mut buckets = self.buckets.write().await;
44        let now = Instant::now();
45        let window = Duration::from_secs(RATE_WINDOW_SECS);
46
47        // Purge stale entries periodically to prevent unbounded growth
48        if buckets.len() > RATE_CLEANUP_THRESHOLD {
49            buckets.retain(|_, (_, ts)| now.duration_since(*ts) <= window);
50        }
51
52        let entry = buckets.entry(key.to_string()).or_insert((0, now));
53        if now.duration_since(entry.1) > window {
54            *entry = (1, now);
55            return true;
56        }
57        if entry.0 >= limit {
58            return false;
59        }
60        entry.0 += 1;
61        true
62    }
63}
64
65/// Shared state for OAuth endpoints.
66#[derive(Clone)]
67pub struct OAuthState {
68    pool: sqlx::PgPool,
69    auth_middleware: Arc<AuthMiddleware>,
70    token_issuer: Arc<dyn forge_core::TokenIssuer>,
71    access_token_ttl_secs: i64,
72    refresh_token_ttl_days: i64,
73    auth_is_hmac: bool,
74    project_name: String,
75    jwt_secret: String,
76    rate_limiter: OAuthRateLimiter,
77    /// CSRF tokens: token -> expiry
78    csrf_tokens: Arc<RwLock<HashMap<String, Instant>>>,
79}
80
81impl OAuthState {
82    #[allow(clippy::too_many_arguments)]
83    pub fn new(
84        pool: sqlx::PgPool,
85        auth_middleware: Arc<AuthMiddleware>,
86        token_issuer: Arc<dyn forge_core::TokenIssuer>,
87        access_token_ttl_secs: i64,
88        refresh_token_ttl_days: i64,
89        auth_is_hmac: bool,
90        project_name: String,
91        jwt_secret: String,
92    ) -> Self {
93        Self {
94            pool,
95            auth_middleware,
96            token_issuer,
97            access_token_ttl_secs,
98            refresh_token_ttl_days,
99            auth_is_hmac,
100            project_name,
101            jwt_secret,
102            rate_limiter: OAuthRateLimiter::default(),
103            csrf_tokens: Arc::new(RwLock::new(HashMap::new())),
104        }
105    }
106
107    async fn store_csrf(&self, token: &str) {
108        let mut tokens = self.csrf_tokens.write().await;
109        let now = Instant::now();
110        let expiry = now + Duration::from_secs(600); // 10 min
111        tokens.insert(token.to_string(), expiry);
112        // Purge expired tokens periodically, not on every insert
113        if tokens.len() > RATE_CLEANUP_THRESHOLD {
114            tokens.retain(|_, exp| *exp > now);
115        }
116    }
117
118    async fn validate_csrf(&self, token: &str) -> bool {
119        let mut tokens = self.csrf_tokens.write().await;
120        if let Some(expiry) = tokens.remove(token) {
121            expiry > Instant::now()
122        } else {
123            false
124        }
125    }
126}
127
128// ── Well-known metadata endpoints ──────────────────────────────────────
129
130#[derive(Serialize)]
131pub struct AuthorizationServerMetadata {
132    issuer: String,
133    authorization_endpoint: String,
134    token_endpoint: String,
135    registration_endpoint: String,
136    response_types_supported: Vec<String>,
137    grant_types_supported: Vec<String>,
138    code_challenge_methods_supported: Vec<String>,
139    token_endpoint_auth_methods_supported: Vec<String>,
140}
141
142pub async fn well_known_oauth_metadata(
143    headers: HeaderMap,
144    State(_state): State<Arc<OAuthState>>,
145) -> Json<AuthorizationServerMetadata> {
146    let base = base_url_from_headers(&headers);
147    Json(AuthorizationServerMetadata {
148        issuer: base.clone(),
149        authorization_endpoint: format!("{base}/_api/oauth/authorize"),
150        token_endpoint: format!("{base}/_api/oauth/token"),
151        registration_endpoint: format!("{base}/_api/oauth/register"),
152        response_types_supported: vec!["code".into()],
153        grant_types_supported: vec!["authorization_code".into(), "refresh_token".into()],
154        code_challenge_methods_supported: vec![CHALLENGE_METHOD_S256.into()],
155        token_endpoint_auth_methods_supported: vec!["none".into()],
156    })
157}
158
159#[derive(Serialize)]
160pub struct ProtectedResourceMetadata {
161    resource: String,
162    authorization_servers: Vec<String>,
163}
164
165pub async fn well_known_resource_metadata(
166    headers: HeaderMap,
167    State(_state): State<Arc<OAuthState>>,
168) -> Json<ProtectedResourceMetadata> {
169    let base = base_url_from_headers(&headers);
170    Json(ProtectedResourceMetadata {
171        resource: base.clone(),
172        authorization_servers: vec![base],
173    })
174}
175
176// ── Dynamic client registration ────────────────────────────────────────
177
178#[derive(Deserialize)]
179pub struct RegisterRequest {
180    pub client_name: Option<String>,
181    pub redirect_uris: Vec<String>,
182    #[serde(default)]
183    pub grant_types: Vec<String>,
184    #[serde(default)]
185    pub token_endpoint_auth_method: Option<String>,
186}
187
188#[derive(Serialize)]
189pub struct RegisterResponse {
190    pub client_id: String,
191    pub client_name: Option<String>,
192    pub redirect_uris: Vec<String>,
193    pub grant_types: Vec<String>,
194    pub token_endpoint_auth_method: String,
195}
196
197pub async fn oauth_register(
198    headers: HeaderMap,
199    State(state): State<Arc<OAuthState>>,
200    Json(req): Json<RegisterRequest>,
201) -> Response {
202    let ip = client_ip(&headers);
203    let rate_key = format!("oauth_register:{ip}");
204    if !state
205        .rate_limiter
206        .check(&rate_key, REGISTER_RATE_LIMIT)
207        .await
208    {
209        return (
210            StatusCode::TOO_MANY_REQUESTS,
211            Json(serde_json::json!({
212                "error": "too_many_requests",
213                "error_description": "Rate limit exceeded for client registration"
214            })),
215        )
216            .into_response();
217    }
218
219    // Check client cap
220    let count: i64 = sqlx::query_scalar!("SELECT COUNT(*) FROM forge_oauth_clients")
221        .fetch_one(&state.pool)
222        .await
223        .unwrap_or(Some(0))
224        .unwrap_or(0);
225    if count >= MAX_REGISTERED_CLIENTS {
226        return (
227            StatusCode::BAD_REQUEST,
228            Json(serde_json::json!({
229                "error": "too_many_clients",
230                "error_description": "Maximum number of registered clients reached"
231            })),
232        )
233            .into_response();
234    }
235
236    if req.redirect_uris.is_empty() {
237        return (
238            StatusCode::BAD_REQUEST,
239            Json(serde_json::json!({
240                "error": "invalid_client_metadata",
241                "error_description": "redirect_uris is required"
242            })),
243        )
244            .into_response();
245    }
246
247    // Validate redirect URIs: reject fragments, require HTTPS for non-localhost
248    for uri in &req.redirect_uris {
249        // Reject URIs with fragments (per OAuth 2.1 spec)
250        if uri.contains('#') {
251            return (
252                StatusCode::BAD_REQUEST,
253                Json(serde_json::json!({
254                    "error": "invalid_redirect_uri",
255                    "error_description": "redirect_uri must not contain a fragment"
256                })),
257            )
258                .into_response();
259        }
260        // Check scheme and host
261        let is_localhost = uri.starts_with("http://localhost")
262            || uri.starts_with("http://127.0.0.1")
263            || uri.starts_with("http://[::1]");
264        let is_https = uri.starts_with("https://");
265        if !is_localhost && !is_https {
266            return (
267                StatusCode::BAD_REQUEST,
268                Json(serde_json::json!({
269                    "error": "invalid_redirect_uri",
270                    "error_description": "redirect_uri must use HTTPS for non-localhost URIs"
271                })),
272            )
273                .into_response();
274        }
275    }
276
277    let client_id = Uuid::new_v4().to_string();
278    let auth_method = req.token_endpoint_auth_method.as_deref().unwrap_or("none");
279
280    let result = sqlx::query!(
281        "INSERT INTO forge_oauth_clients (client_id, client_name, redirect_uris, token_endpoint_auth_method) \
282         VALUES ($1, $2, $3, $4)",
283        &client_id,
284        req.client_name as _,
285        &req.redirect_uris,
286        auth_method,
287    )
288    .execute(&state.pool)
289    .await;
290
291    if let Err(e) = result {
292        tracing::error!("Failed to register OAuth client: {e}");
293        return (
294            StatusCode::INTERNAL_SERVER_ERROR,
295            Json(serde_json::json!({
296                "error": "server_error",
297                "error_description": "Failed to register client"
298            })),
299        )
300            .into_response();
301    }
302
303    let grant_types = if req.grant_types.is_empty() {
304        vec!["authorization_code".into()]
305    } else {
306        req.grant_types
307    };
308
309    (
310        StatusCode::CREATED,
311        Json(RegisterResponse {
312            client_id,
313            client_name: req.client_name,
314            redirect_uris: req.redirect_uris,
315            grant_types,
316            token_endpoint_auth_method: auth_method.to_string(),
317        }),
318    )
319        .into_response()
320}
321
322// ── Authorization endpoint ─────────────────────────────────────────────
323
324#[derive(Deserialize)]
325pub struct AuthorizeQuery {
326    pub client_id: String,
327    pub redirect_uri: String,
328    pub code_challenge: String,
329    #[serde(default = "default_s256")]
330    pub code_challenge_method: String,
331    pub state: Option<String>,
332    pub scope: Option<String>,
333    pub response_type: Option<String>,
334}
335
336fn default_s256() -> String {
337    CHALLENGE_METHOD_S256.into()
338}
339
340pub async fn oauth_authorize_get(
341    headers: HeaderMap,
342    Query(params): Query<AuthorizeQuery>,
343    State(state): State<Arc<OAuthState>>,
344) -> Response {
345    // Validate client_id
346    let client = sqlx::query!(
347        "SELECT client_id, client_name, redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
348        &params.client_id,
349    )
350    .fetch_optional(&state.pool)
351    .await;
352
353    let (_, client_name, redirect_uris) = match client {
354        Ok(Some(c)) => (c.client_id, c.client_name, c.redirect_uris),
355        Ok(None) => {
356            return (
357                StatusCode::BAD_REQUEST,
358                Json(serde_json::json!({
359                    "error": "invalid_client",
360                    "error_description": "Unknown client_id"
361                })),
362            )
363                .into_response();
364        }
365        Err(e) => {
366            tracing::error!("OAuth client lookup failed: {e}");
367            return (
368                StatusCode::INTERNAL_SERVER_ERROR,
369                Json(serde_json::json!({
370                    "error": "server_error"
371                })),
372            )
373                .into_response();
374        }
375    };
376
377    // Validate redirect_uri (exact match, T2)
378    if !validate_redirect_uri(&params.redirect_uri, &redirect_uris) {
379        return (
380            StatusCode::BAD_REQUEST,
381            Json(serde_json::json!({
382                "error": "invalid_redirect_uri",
383                "error_description": "redirect_uri does not match any registered URI"
384            })),
385        )
386            .into_response();
387    }
388
389    if params.code_challenge_method != CHALLENGE_METHOD_S256 {
390        return (
391            StatusCode::BAD_REQUEST,
392            Json(serde_json::json!({
393                "error": "invalid_request",
394                "error_description": "Only S256 code_challenge_method is supported"
395            })),
396        )
397            .into_response();
398    }
399
400    // Check for existing session cookie (set by auth middleware on API calls)
401    let session_subject = extract_cookie(&headers, "forge_session")
402        .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
403    let has_session = session_subject.is_some();
404
405    // Generate CSRF token (T5)
406    let csrf_token = oauth::generate_random_token();
407    state.store_csrf(&csrf_token).await;
408
409    let auth_mode = if has_session {
410        "session" // user is known from cookie, show consent directly
411    } else if state.auth_is_hmac {
412        "hmac" // show email/password form
413    } else {
414        "external" // show "log in to your app first"
415    };
416    let display_name = client_name.as_deref().unwrap_or(&params.client_id);
417
418    let html = AUTHORIZE_PAGE
419        .replace("{{app_name}}", &html_escape(&state.project_name))
420        .replace("{{client_name}}", &html_escape(display_name))
421        .replace("{{csrf_token}}", &csrf_token)
422        .replace("{{client_id}}", &html_escape(&params.client_id))
423        .replace("{{redirect_uri}}", &html_escape(&params.redirect_uri))
424        .replace("{{code_challenge}}", &html_escape(&params.code_challenge))
425        .replace(
426            "{{code_challenge_method}}",
427            &html_escape(&params.code_challenge_method),
428        )
429        .replace(
430            "{{state}}",
431            &html_escape(params.state.as_deref().unwrap_or("")),
432        )
433        .replace(
434            "{{scope}}",
435            &html_escape(params.scope.as_deref().unwrap_or("")),
436        )
437        .replace("{{auth_mode}}", auth_mode)
438        .replace("{{authorize_url}}", "/_api/oauth/authorize")
439        .replace("{{error_message}}", "");
440
441    let mut response = (StatusCode::OK, Html(html)).into_response();
442    // T17: clickjacking protection
443    response
444        .headers_mut()
445        .insert("X-Frame-Options", HeaderValue::from_static("DENY"));
446    response.headers_mut().insert(
447        "Content-Security-Policy",
448        HeaderValue::from_static("frame-ancestors 'none'"),
449    );
450    // Set CSRF cookie (T1, T5)
451    let csrf_secure_flag = if is_https(&headers) { "; Secure" } else { "" };
452    let cookie = format!(
453        "forge_oauth_csrf={csrf_token}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=600{csrf_secure_flag}"
454    );
455    if let Ok(cookie_val) = HeaderValue::from_str(&cookie) {
456        response
457            .headers_mut()
458            .insert(header::SET_COOKIE, cookie_val);
459    }
460    response
461}
462
463#[derive(Deserialize)]
464pub struct AuthorizeForm {
465    pub csrf_token: String,
466    pub client_id: String,
467    pub redirect_uri: String,
468    pub code_challenge: String,
469    pub code_challenge_method: String,
470    pub state: Option<String>,
471    pub scope: Option<String>,
472    pub response_type: Option<String>,
473    // Consent flow: existing token from localStorage
474    pub token: Option<String>,
475    // Login flow: email/password
476    pub email: Option<String>,
477    pub password: Option<String>,
478}
479
480pub async fn oauth_authorize_post(
481    headers: HeaderMap,
482    State(state): State<Arc<OAuthState>>,
483    axum::Form(form): axum::Form<AuthorizeForm>,
484) -> Response {
485    // Validate CSRF (T5): check both cookie and form value
486    let csrf_from_cookie = extract_cookie(&headers, "forge_oauth_csrf");
487    let csrf_valid = if let Some(cookie_csrf) = csrf_from_cookie {
488        cookie_csrf == form.csrf_token && state.validate_csrf(&form.csrf_token).await
489    } else {
490        false
491    };
492    if !csrf_valid {
493        return (
494            StatusCode::FORBIDDEN,
495            Json(serde_json::json!({
496                "error": "csrf_validation_failed",
497                "error_description": "Invalid or expired CSRF token. Please try again."
498            })),
499        )
500            .into_response();
501    }
502
503    // Rate limit login failures (T7)
504    let ip = client_ip(&headers);
505    let rate_key = format!("oauth_login:{ip}");
506
507    // Validate client and redirect_uri again (form could be tampered)
508    let client = sqlx::query!(
509        "SELECT redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
510        &form.client_id,
511    )
512    .fetch_optional(&state.pool)
513    .await;
514
515    let redirect_uris = match client {
516        Ok(Some(c)) => c.redirect_uris,
517        _ => {
518            return (
519                StatusCode::BAD_REQUEST,
520                Json(serde_json::json!({
521                    "error": "invalid_client"
522                })),
523            )
524                .into_response();
525        }
526    };
527
528    if !validate_redirect_uri(&form.redirect_uri, &redirect_uris) {
529        return (
530            StatusCode::BAD_REQUEST,
531            Json(serde_json::json!({
532                "error": "invalid_redirect_uri"
533            })),
534        )
535            .into_response();
536    }
537
538    // Authenticate user: try session cookie first, then token, then email/password
539    let user_id: Uuid;
540
541    let session_subject = extract_cookie(&headers, "forge_session")
542        .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
543
544    if let Some(subject) = session_subject {
545        // Session cookie flow: user identified by signed cookie from previous API calls.
546        // Subject may be a UUID (HMAC auth) or an external provider ID (Firebase, Clerk).
547        user_id = subject.parse::<Uuid>().unwrap_or_else(|_| {
548            // Non-UUID subject (Firebase UID, etc.): deterministic UUID from subject hash.
549            use sha2::Digest;
550            let hash: [u8; 32] = sha2::Sha256::digest(subject.as_bytes()).into();
551            let mut bytes = [0u8; 16];
552            bytes.copy_from_slice(&hash[..16]);
553            Uuid::from_bytes(bytes)
554        });
555    } else if let Some(token) = &form.token {
556        // Consent flow: validate existing JWT
557        match state.auth_middleware.validate_token_async(token).await {
558            Ok(claims) => {
559                user_id = claims
560                    .user_id()
561                    .ok_or(())
562                    .map_err(|_| ())
563                    .unwrap_or_default();
564                if user_id.is_nil() {
565                    return authorize_error_redirect(
566                        &form.redirect_uri,
567                        form.state.as_deref(),
568                        "access_denied",
569                        "Invalid user identity in token",
570                    );
571                }
572            }
573            Err(_) => {
574                return authorize_error_redirect(
575                    &form.redirect_uri,
576                    form.state.as_deref(),
577                    "access_denied",
578                    "Invalid or expired token. Please log in again.",
579                );
580            }
581        }
582    } else if let (Some(email), Some(password)) = (&form.email, &form.password) {
583        // Login flow (HMAC mode only)
584        if !state.auth_is_hmac {
585            return authorize_error_redirect(
586                &form.redirect_uri,
587                form.state.as_deref(),
588                "access_denied",
589                "Direct login not supported with external auth provider",
590            );
591        }
592
593        if !state
594            .rate_limiter
595            .check(&rate_key, LOGIN_FAIL_RATE_LIMIT)
596            .await
597        {
598            return authorize_error_redirect(
599                &form.redirect_uri,
600                form.state.as_deref(),
601                "access_denied",
602                "Too many login attempts. Please try again later.",
603            );
604        }
605
606        // Query users table by convention
607        let row = sqlx::query!(
608            "SELECT id, password_hash, role::TEXT FROM users WHERE email = $1",
609            email,
610        )
611        .fetch_optional(&state.pool)
612        .await;
613
614        match row {
615            Ok(Some(r)) if r.password_hash.is_some() => {
616                match bcrypt::verify(
617                    password,
618                    r.password_hash.as_ref().expect("guarded by is_some check"),
619                ) {
620                    Ok(true) => {
621                        user_id = r.id;
622                    }
623                    _ => {
624                        return authorize_error_redirect(
625                            &form.redirect_uri,
626                            form.state.as_deref(),
627                            "access_denied",
628                            "Invalid email or password",
629                        );
630                    }
631                }
632            }
633            _ => {
634                return authorize_error_redirect(
635                    &form.redirect_uri,
636                    form.state.as_deref(),
637                    "access_denied",
638                    "Invalid email or password",
639                );
640            }
641        }
642    } else {
643        return (
644            StatusCode::BAD_REQUEST,
645            Json(serde_json::json!({
646                "error": "invalid_request",
647                "error_description": "Must provide either a token or email/password"
648            })),
649        )
650            .into_response();
651    }
652
653    // Generate authorization code
654    let code = oauth::generate_random_token();
655    let expires_at = Utc::now() + chrono::Duration::seconds(AUTH_CODE_TTL_SECS);
656    let scopes: Vec<String> = form
657        .scope
658        .as_deref()
659        .map(|s| s.split_whitespace().map(String::from).collect())
660        .unwrap_or_default();
661
662    let result = sqlx::query!(
663        "INSERT INTO forge_oauth_codes \
664         (code, client_id, user_id, redirect_uri, code_challenge, code_challenge_method, scopes, expires_at) \
665         VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
666        &code,
667        &form.client_id,
668        user_id,
669        &form.redirect_uri,
670        &form.code_challenge,
671        &form.code_challenge_method,
672        &scopes,
673        expires_at,
674    )
675    .execute(&state.pool)
676    .await;
677
678    if let Err(e) = result {
679        tracing::error!("Failed to store authorization code: {e}");
680        return authorize_error_redirect(
681            &form.redirect_uri,
682            form.state.as_deref(),
683            "server_error",
684            "Failed to generate authorization code",
685        );
686    }
687
688    // Redirect with code (T18: Referrer-Policy)
689    let mut redirect_url = format!("{}?code={}", form.redirect_uri, urlencoding(&code));
690    if let Some(st) = &form.state {
691        redirect_url.push_str(&format!("&state={}", urlencoding(st)));
692    }
693
694    let mut response = Redirect::to(&redirect_url).into_response();
695    response
696        .headers_mut()
697        .insert("Referrer-Policy", HeaderValue::from_static("no-referrer"));
698
699    // Set session cookie so the next authorize visit shows consent directly
700    // instead of the login form. This is same-origin (backend serves both
701    // the authorize page and this POST), so the cookie sticks.
702    let cookie_value = super::auth::sign_session_cookie(&user_id.to_string(), &state.jwt_secret);
703    let secure_flag = if is_https(&headers) { "; Secure" } else { "" };
704    let session_cookie = format!(
705        "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=86400{secure_flag}"
706    );
707    if let Ok(val) = HeaderValue::from_str(&session_cookie) {
708        response.headers_mut().append(header::SET_COOKIE, val);
709    }
710
711    response
712}
713
714// ── Token endpoint ─────────────────────────────────────────────────────
715
716#[derive(Deserialize)]
717pub struct TokenRequest {
718    pub grant_type: String,
719    pub code: Option<String>,
720    pub redirect_uri: Option<String>,
721    pub code_verifier: Option<String>,
722    pub client_id: Option<String>,
723    pub refresh_token: Option<String>,
724}
725
726#[derive(Serialize)]
727pub struct TokenResponse {
728    pub access_token: String,
729    pub token_type: String,
730    pub expires_in: i64,
731    pub refresh_token: String,
732}
733
734/// Token endpoint accepts both `application/json` and
735/// `application/x-www-form-urlencoded` (OAuth 2.1 standard).
736pub async fn oauth_token(
737    State(state): State<Arc<OAuthState>>,
738    headers: HeaderMap,
739    body: axum::body::Bytes,
740) -> Response {
741    let content_type = headers
742        .get(header::CONTENT_TYPE)
743        .and_then(|v| v.to_str().ok())
744        .unwrap_or("");
745
746    let req: TokenRequest = if content_type.starts_with("application/json") {
747        match serde_json::from_slice(&body) {
748            Ok(r) => r,
749            Err(e) => return token_error("invalid_request", &format!("Invalid JSON: {e}")),
750        }
751    } else {
752        // Default to form-urlencoded (OAuth standard)
753        match serde_urlencoded::from_bytes(&body) {
754            Ok(r) => r,
755            Err(e) => return token_error("invalid_request", &format!("Invalid form data: {e}")),
756        }
757    };
758
759    match req.grant_type.as_str() {
760        "authorization_code" => handle_code_exchange(&state, &req).await,
761        "refresh_token" => handle_refresh(&state, &req).await,
762        _ => (
763            StatusCode::BAD_REQUEST,
764            Json(serde_json::json!({
765                "error": "unsupported_grant_type"
766            })),
767        )
768            .into_response(),
769    }
770}
771
772async fn handle_code_exchange(state: &OAuthState, req: &TokenRequest) -> Response {
773    let code = match &req.code {
774        Some(c) => c,
775        None => return token_error("invalid_request", "code is required"),
776    };
777    let code_verifier = match &req.code_verifier {
778        Some(v) => v,
779        None => return token_error("invalid_request", "code_verifier is required"),
780    };
781    let redirect_uri = match &req.redirect_uri {
782        Some(r) => r,
783        None => return token_error("invalid_request", "redirect_uri is required"),
784    };
785    let client_id = match &req.client_id {
786        Some(c) => c,
787        None => return token_error("invalid_request", "client_id is required"),
788    };
789
790    // Atomic exchange: mark code as used and fetch in one query (T8)
791    let row = sqlx::query!(
792        "UPDATE forge_oauth_codes SET used_at = now() \
793         WHERE code = $1 AND used_at IS NULL \
794         RETURNING client_id, user_id, redirect_uri, code_challenge, code_challenge_method, expires_at",
795        code,
796    )
797    .fetch_optional(&state.pool)
798    .await;
799
800    let (
801        stored_client_id,
802        user_id,
803        stored_redirect,
804        stored_challenge,
805        challenge_method,
806        expires_at,
807    ) = match row {
808        Ok(Some(r)) => (
809            r.client_id,
810            r.user_id,
811            r.redirect_uri,
812            r.code_challenge,
813            r.code_challenge_method,
814            r.expires_at,
815        ),
816        Ok(None) => {
817            return token_error(
818                "invalid_grant",
819                "Invalid or already used authorization code",
820            );
821        }
822        Err(e) => {
823            tracing::error!("Failed to exchange authorization code: {e}");
824            return token_error("server_error", "Failed to exchange code");
825        }
826    };
827
828    // Check expiry
829    if Utc::now() > expires_at {
830        return token_error("invalid_grant", "Authorization code has expired");
831    }
832
833    // Verify client_id matches (T14)
834    if *client_id != stored_client_id {
835        return token_error("invalid_grant", "client_id does not match");
836    }
837
838    // Verify redirect_uri matches
839    if *redirect_uri != stored_redirect {
840        return token_error("invalid_grant", "redirect_uri does not match");
841    }
842
843    if challenge_method != CHALLENGE_METHOD_S256 {
844        return token_error("invalid_request", "Unsupported code_challenge_method");
845    }
846    if !forge_core::oauth::pkce::verify_s256(code_verifier, &stored_challenge) {
847        return token_error("invalid_grant", "PKCE verification failed");
848    }
849
850    let access_ttl = state.access_token_ttl_secs;
851    let refresh_ttl = state.refresh_token_ttl_days;
852
853    let pair = forge_core::auth::tokens::issue_token_pair_with_client(
854        &state.pool,
855        user_id,
856        &["user"],
857        access_ttl,
858        refresh_ttl,
859        Some(client_id),
860        mcp_token_issuer(state.token_issuer.clone()),
861    )
862    .await;
863
864    match pair {
865        Ok(pair) => (
866            StatusCode::OK,
867            Json(TokenResponse {
868                access_token: pair.access_token,
869                token_type: "Bearer".into(),
870                expires_in: access_ttl,
871                refresh_token: pair.refresh_token,
872            }),
873        )
874            .into_response(),
875        Err(e) => {
876            tracing::error!("Failed to issue token pair: {e}");
877            token_error("server_error", "Failed to issue tokens")
878        }
879    }
880}
881
882async fn handle_refresh(state: &OAuthState, req: &TokenRequest) -> Response {
883    let refresh_token = match &req.refresh_token {
884        Some(t) => t,
885        None => return token_error("invalid_request", "refresh_token is required"),
886    };
887    let client_id = req.client_id.as_deref();
888
889    let access_ttl = state.access_token_ttl_secs;
890    let refresh_ttl = state.refresh_token_ttl_days;
891
892    let pair = forge_core::auth::tokens::rotate_refresh_token_with_client(
893        &state.pool,
894        refresh_token,
895        &["user"],
896        access_ttl,
897        refresh_ttl,
898        client_id,
899        mcp_token_issuer(state.token_issuer.clone()),
900    )
901    .await;
902
903    match pair {
904        Ok(pair) => (
905            StatusCode::OK,
906            Json(TokenResponse {
907                access_token: pair.access_token,
908                token_type: "Bearer".into(),
909                expires_in: access_ttl,
910                refresh_token: pair.refresh_token,
911            }),
912        )
913            .into_response(),
914        Err(_) => token_error("invalid_grant", "Invalid or expired refresh token"),
915    }
916}
917
918// ── Helpers ────────────────────────────────────────────────────────────
919
920/// Build a token-signing closure scoped to MCP audience.
921fn mcp_token_issuer(
922    issuer: Arc<dyn forge_core::TokenIssuer>,
923) -> impl FnOnce(Uuid, &[&str], i64) -> forge_core::Result<String> {
924    move |uid, roles, ttl| {
925        let claims = Claims::builder()
926            .subject(uid)
927            .roles(roles.iter().map(|s| s.to_string()).collect())
928            .claim("aud".to_string(), serde_json::json!(MCP_AUDIENCE))
929            .duration_secs(ttl)
930            .build()
931            .map_err(forge_core::ForgeError::Internal)?;
932        issuer.sign(&claims)
933    }
934}
935
936fn is_https(headers: &HeaderMap) -> bool {
937    headers
938        .get("x-forwarded-proto")
939        .and_then(|v| v.to_str().ok())
940        .map(|s| s == "https")
941        .unwrap_or(false)
942}
943
944fn token_error(error: &str, description: &str) -> Response {
945    (
946        StatusCode::BAD_REQUEST,
947        Json(serde_json::json!({
948            "error": error,
949            "error_description": description
950        })),
951    )
952        .into_response()
953}
954
955fn authorize_error_redirect(
956    redirect_uri: &str,
957    state: Option<&str>,
958    error: &str,
959    description: &str,
960) -> Response {
961    let mut url = format!(
962        "{}?error={}&error_description={}",
963        redirect_uri,
964        urlencoding(error),
965        urlencoding(description),
966    );
967    if let Some(st) = state {
968        url.push_str(&format!("&state={}", urlencoding(st)));
969    }
970    Redirect::to(&url).into_response()
971}
972
973fn base_url_from_headers(headers: &HeaderMap) -> String {
974    let host = headers
975        .get("host")
976        .and_then(|v| v.to_str().ok())
977        .unwrap_or("localhost:9081");
978
979    let scheme = headers
980        .get("x-forwarded-proto")
981        .and_then(|v| v.to_str().ok())
982        .unwrap_or("http");
983
984    format!("{scheme}://{host}")
985}
986
987fn client_ip(headers: &HeaderMap) -> String {
988    headers
989        .get("x-forwarded-for")
990        .and_then(|v| v.to_str().ok())
991        .and_then(|s| s.split(',').next())
992        .map(|s| s.trim().to_string())
993        .or_else(|| {
994            headers
995                .get("x-real-ip")
996                .and_then(|v| v.to_str().ok())
997                .map(String::from)
998        })
999        .unwrap_or_else(|| "unknown".to_string())
1000}
1001
1002fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
1003    headers
1004        .get(header::COOKIE)
1005        .and_then(|v| v.to_str().ok())
1006        .and_then(|cookies| {
1007            cookies.split(';').map(|c| c.trim()).find_map(|c| {
1008                let (k, v) = c.split_once('=')?;
1009                if k == name { Some(v.to_string()) } else { None }
1010            })
1011        })
1012}
1013
/// Escapes the five HTML-significant characters (`&`, `<`, `>`, `"`, `'`) so
/// the text can be embedded in HTML content or attribute values.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
1021
/// Percent-encodes `s` for use as an OAuth query parameter value.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged; every
/// other byte of the UTF-8 encoding becomes `%XX` with uppercase hex digits.
fn urlencoding(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let is_unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if is_unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}