1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use axum::Json;
11use axum::extract::{Query, State};
12use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
13use axum::response::{Html, IntoResponse, Redirect, Response};
14use chrono::Utc;
15use forge_core::auth::Claims;
16use forge_core::oauth::{self, validate_redirect_uri};
17use serde::{Deserialize, Serialize};
18use tokio::sync::RwLock;
19use uuid::Uuid;
20
21use super::auth::AuthMiddleware;
22
/// HTML template for the authorization page, embedded at compile time.
/// `oauth_authorize_get` substitutes its `{{placeholder}}` markers before serving it.
const AUTHORIZE_PAGE: &str = include_str!("oauth_authorize.html");
/// Lifetime of an issued authorization code, in seconds.
const AUTH_CODE_TTL_SECS: i64 = 60;
/// Registration is refused once `forge_oauth_clients` holds this many rows.
const MAX_REGISTERED_CLIENTS: i64 = 1000;
/// The only PKCE `code_challenge_method` this module accepts.
const CHALLENGE_METHOD_S256: &str = "S256";
/// Value of the `aud` claim placed on access tokens minted by this module.
const MCP_AUDIENCE: &str = "forge:mcp";

/// Client registrations allowed per rate-limit key within one window.
const REGISTER_RATE_LIMIT: u32 = 10;
/// Password login attempts allowed per rate-limit key within one window.
const LOGIN_FAIL_RATE_LIMIT: u32 = 5;
/// Length of one rate-limit window, in seconds.
const RATE_WINDOW_SECS: u64 = 60;
/// Map size above which expired entries are purged on the next access.
const RATE_CLEANUP_THRESHOLD: usize = 100;
34
/// In-memory, fixed-window rate limiter keyed by an arbitrary string.
///
/// Each bucket stores `(requests counted in the current window, window start)`.
/// Clones share the same underlying map through the `Arc`.
#[derive(Clone, Default)]
struct OAuthRateLimiter {
    buckets: Arc<RwLock<HashMap<String, (u32, Instant)>>>,
}
40
41impl OAuthRateLimiter {
42 async fn check(&self, key: &str, limit: u32) -> bool {
43 let mut buckets = self.buckets.write().await;
44 let now = Instant::now();
45 let window = Duration::from_secs(RATE_WINDOW_SECS);
46
47 if buckets.len() > RATE_CLEANUP_THRESHOLD {
49 buckets.retain(|_, (_, ts)| now.duration_since(*ts) <= window);
50 }
51
52 let entry = buckets.entry(key.to_string()).or_insert((0, now));
53 if now.duration_since(entry.1) > window {
54 *entry = (1, now);
55 return true;
56 }
57 if entry.0 >= limit {
58 return false;
59 }
60 entry.0 += 1;
61 true
62 }
63}
64
/// Shared state for the OAuth handlers in this module.
#[derive(Clone)]
pub struct OAuthState {
    /// Database pool used for client, authorization-code and user queries.
    pool: sqlx::PgPool,
    /// Validates bearer tokens submitted through the authorize form.
    auth_middleware: Arc<AuthMiddleware>,
    /// Signs the claims built by `mcp_token_issuer`.
    token_issuer: Arc<dyn forge_core::TokenIssuer>,
    /// Access-token lifetime in seconds; also reported as `expires_in`.
    access_token_ttl_secs: i64,
    /// Refresh-token lifetime in days, passed through to token issuance.
    refresh_token_ttl_days: i64,
    /// When `false`, email/password login on the authorize form is refused.
    auth_is_hmac: bool,
    /// Application name rendered on the authorization page.
    project_name: String,
    /// Secret used to sign and verify the `forge_session` cookie.
    jwt_secret: String,
    /// Limits client registrations and password login attempts.
    rate_limiter: OAuthRateLimiter,
    /// Outstanding CSRF tokens mapped to the instant at which they expire.
    csrf_tokens: Arc<RwLock<HashMap<String, Instant>>>,
}
80
81impl OAuthState {
82 #[allow(clippy::too_many_arguments)]
83 pub fn new(
84 pool: sqlx::PgPool,
85 auth_middleware: Arc<AuthMiddleware>,
86 token_issuer: Arc<dyn forge_core::TokenIssuer>,
87 access_token_ttl_secs: i64,
88 refresh_token_ttl_days: i64,
89 auth_is_hmac: bool,
90 project_name: String,
91 jwt_secret: String,
92 ) -> Self {
93 Self {
94 pool,
95 auth_middleware,
96 token_issuer,
97 access_token_ttl_secs,
98 refresh_token_ttl_days,
99 auth_is_hmac,
100 project_name,
101 jwt_secret,
102 rate_limiter: OAuthRateLimiter::default(),
103 csrf_tokens: Arc::new(RwLock::new(HashMap::new())),
104 }
105 }
106
107 async fn store_csrf(&self, token: &str) {
108 let mut tokens = self.csrf_tokens.write().await;
109 let now = Instant::now();
110 let expiry = now + Duration::from_secs(600); tokens.insert(token.to_string(), expiry);
112 if tokens.len() > RATE_CLEANUP_THRESHOLD {
114 tokens.retain(|_, exp| *exp > now);
115 }
116 }
117
118 async fn validate_csrf(&self, token: &str) -> bool {
119 let mut tokens = self.csrf_tokens.write().await;
120 if let Some(expiry) = tokens.remove(token) {
121 expiry > Instant::now()
122 } else {
123 false
124 }
125 }
126}
127
/// JSON body returned by `well_known_oauth_metadata`: the endpoints and
/// capabilities this authorization server advertises.
#[derive(Serialize)]
pub struct AuthorizationServerMetadata {
    /// Base URL of this server as derived from the request headers.
    issuer: String,
    authorization_endpoint: String,
    token_endpoint: String,
    registration_endpoint: String,
    response_types_supported: Vec<String>,
    grant_types_supported: Vec<String>,
    code_challenge_methods_supported: Vec<String>,
    token_endpoint_auth_methods_supported: Vec<String>,
}
141
142pub async fn well_known_oauth_metadata(
143 headers: HeaderMap,
144 State(_state): State<Arc<OAuthState>>,
145) -> Json<AuthorizationServerMetadata> {
146 let base = base_url_from_headers(&headers);
147 Json(AuthorizationServerMetadata {
148 issuer: base.clone(),
149 authorization_endpoint: format!("{base}/_api/oauth/authorize"),
150 token_endpoint: format!("{base}/_api/oauth/token"),
151 registration_endpoint: format!("{base}/_api/oauth/register"),
152 response_types_supported: vec!["code".into()],
153 grant_types_supported: vec!["authorization_code".into(), "refresh_token".into()],
154 code_challenge_methods_supported: vec![CHALLENGE_METHOD_S256.into()],
155 token_endpoint_auth_methods_supported: vec!["none".into()],
156 })
157}
158
/// JSON body returned by `well_known_resource_metadata`.
#[derive(Serialize)]
pub struct ProtectedResourceMetadata {
    /// Base URL of the protected resource.
    resource: String,
    /// Authorization servers that can issue tokens for `resource`.
    authorization_servers: Vec<String>,
}
164
165pub async fn well_known_resource_metadata(
166 headers: HeaderMap,
167 State(_state): State<Arc<OAuthState>>,
168) -> Json<ProtectedResourceMetadata> {
169 let base = base_url_from_headers(&headers);
170 Json(ProtectedResourceMetadata {
171 resource: base.clone(),
172 authorization_servers: vec![base],
173 })
174}
175
/// JSON body accepted by `oauth_register`.
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Optional display name for the client.
    pub client_name: Option<String>,
    /// Redirect URIs to register; `oauth_register` rejects an empty list.
    pub redirect_uris: Vec<String>,
    /// Requested grant types; may be omitted.
    #[serde(default)]
    pub grant_types: Vec<String>,
    /// Requested token endpoint auth method; `oauth_register` falls back to `"none"`.
    #[serde(default)]
    pub token_endpoint_auth_method: Option<String>,
}
187
/// JSON body returned by `oauth_register` after a successful registration.
#[derive(Serialize)]
pub struct RegisterResponse {
    /// Identifier generated for the new client.
    pub client_id: String,
    pub client_name: Option<String>,
    pub redirect_uris: Vec<String>,
    pub grant_types: Vec<String>,
    pub token_endpoint_auth_method: String,
}
196
197pub async fn oauth_register(
198 headers: HeaderMap,
199 State(state): State<Arc<OAuthState>>,
200 Json(req): Json<RegisterRequest>,
201) -> Response {
202 let ip = client_ip(&headers);
203 let rate_key = format!("oauth_register:{ip}");
204 if !state
205 .rate_limiter
206 .check(&rate_key, REGISTER_RATE_LIMIT)
207 .await
208 {
209 return (
210 StatusCode::TOO_MANY_REQUESTS,
211 Json(serde_json::json!({
212 "error": "too_many_requests",
213 "error_description": "Rate limit exceeded for client registration"
214 })),
215 )
216 .into_response();
217 }
218
219 let count: i64 = sqlx::query_scalar!("SELECT COUNT(*) FROM forge_oauth_clients")
221 .fetch_one(&state.pool)
222 .await
223 .unwrap_or(Some(0))
224 .unwrap_or(0);
225 if count >= MAX_REGISTERED_CLIENTS {
226 return (
227 StatusCode::BAD_REQUEST,
228 Json(serde_json::json!({
229 "error": "too_many_clients",
230 "error_description": "Maximum number of registered clients reached"
231 })),
232 )
233 .into_response();
234 }
235
236 if req.redirect_uris.is_empty() {
237 return (
238 StatusCode::BAD_REQUEST,
239 Json(serde_json::json!({
240 "error": "invalid_client_metadata",
241 "error_description": "redirect_uris is required"
242 })),
243 )
244 .into_response();
245 }
246
247 for uri in &req.redirect_uris {
249 if uri.contains('#') {
251 return (
252 StatusCode::BAD_REQUEST,
253 Json(serde_json::json!({
254 "error": "invalid_redirect_uri",
255 "error_description": "redirect_uri must not contain a fragment"
256 })),
257 )
258 .into_response();
259 }
260 let is_localhost = uri.starts_with("http://localhost")
262 || uri.starts_with("http://127.0.0.1")
263 || uri.starts_with("http://[::1]");
264 let is_https = uri.starts_with("https://");
265 if !is_localhost && !is_https {
266 return (
267 StatusCode::BAD_REQUEST,
268 Json(serde_json::json!({
269 "error": "invalid_redirect_uri",
270 "error_description": "redirect_uri must use HTTPS for non-localhost URIs"
271 })),
272 )
273 .into_response();
274 }
275 }
276
277 let client_id = Uuid::new_v4().to_string();
278 let auth_method = req.token_endpoint_auth_method.as_deref().unwrap_or("none");
279
280 let result = sqlx::query!(
281 "INSERT INTO forge_oauth_clients (client_id, client_name, redirect_uris, token_endpoint_auth_method) \
282 VALUES ($1, $2, $3, $4)",
283 &client_id,
284 req.client_name as _,
285 &req.redirect_uris,
286 auth_method,
287 )
288 .execute(&state.pool)
289 .await;
290
291 if let Err(e) = result {
292 tracing::error!("Failed to register OAuth client: {e}");
293 return (
294 StatusCode::INTERNAL_SERVER_ERROR,
295 Json(serde_json::json!({
296 "error": "server_error",
297 "error_description": "Failed to register client"
298 })),
299 )
300 .into_response();
301 }
302
303 let grant_types = if req.grant_types.is_empty() {
304 vec!["authorization_code".into()]
305 } else {
306 req.grant_types
307 };
308
309 (
310 StatusCode::CREATED,
311 Json(RegisterResponse {
312 client_id,
313 client_name: req.client_name,
314 redirect_uris: req.redirect_uris,
315 grant_types,
316 token_endpoint_auth_method: auth_method.to_string(),
317 }),
318 )
319 .into_response()
320}
321
/// Query parameters accepted by `oauth_authorize_get`.
#[derive(Deserialize)]
pub struct AuthorizeQuery {
    pub client_id: String,
    pub redirect_uri: String,
    /// PKCE challenge; required.
    pub code_challenge: String,
    /// PKCE method; defaults to `S256` when the parameter is absent.
    #[serde(default = "default_s256")]
    pub code_challenge_method: String,
    /// Opaque client value carried through the authorization page.
    pub state: Option<String>,
    pub scope: Option<String>,
    pub response_type: Option<String>,
}
335
336fn default_s256() -> String {
337 CHALLENGE_METHOD_S256.into()
338}
339
340pub async fn oauth_authorize_get(
341 headers: HeaderMap,
342 Query(params): Query<AuthorizeQuery>,
343 State(state): State<Arc<OAuthState>>,
344) -> Response {
345 let client = sqlx::query!(
347 "SELECT client_id, client_name, redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
348 ¶ms.client_id,
349 )
350 .fetch_optional(&state.pool)
351 .await;
352
353 let (_, client_name, redirect_uris) = match client {
354 Ok(Some(c)) => (c.client_id, c.client_name, c.redirect_uris),
355 Ok(None) => {
356 return (
357 StatusCode::BAD_REQUEST,
358 Json(serde_json::json!({
359 "error": "invalid_client",
360 "error_description": "Unknown client_id"
361 })),
362 )
363 .into_response();
364 }
365 Err(e) => {
366 tracing::error!("OAuth client lookup failed: {e}");
367 return (
368 StatusCode::INTERNAL_SERVER_ERROR,
369 Json(serde_json::json!({
370 "error": "server_error"
371 })),
372 )
373 .into_response();
374 }
375 };
376
377 if !validate_redirect_uri(¶ms.redirect_uri, &redirect_uris) {
379 return (
380 StatusCode::BAD_REQUEST,
381 Json(serde_json::json!({
382 "error": "invalid_redirect_uri",
383 "error_description": "redirect_uri does not match any registered URI"
384 })),
385 )
386 .into_response();
387 }
388
389 if params.code_challenge_method != CHALLENGE_METHOD_S256 {
390 return (
391 StatusCode::BAD_REQUEST,
392 Json(serde_json::json!({
393 "error": "invalid_request",
394 "error_description": "Only S256 code_challenge_method is supported"
395 })),
396 )
397 .into_response();
398 }
399
400 let session_subject = extract_cookie(&headers, "forge_session")
402 .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
403 let has_session = session_subject.is_some();
404
405 let csrf_token = oauth::generate_random_token();
407 state.store_csrf(&csrf_token).await;
408
409 let auth_mode = if has_session {
410 "session" } else if state.auth_is_hmac {
412 "hmac" } else {
414 "external" };
416 let display_name = client_name.as_deref().unwrap_or(¶ms.client_id);
417
418 let html = AUTHORIZE_PAGE
419 .replace("{{app_name}}", &html_escape(&state.project_name))
420 .replace("{{client_name}}", &html_escape(display_name))
421 .replace("{{csrf_token}}", &csrf_token)
422 .replace("{{client_id}}", &html_escape(¶ms.client_id))
423 .replace("{{redirect_uri}}", &html_escape(¶ms.redirect_uri))
424 .replace("{{code_challenge}}", &html_escape(¶ms.code_challenge))
425 .replace(
426 "{{code_challenge_method}}",
427 &html_escape(¶ms.code_challenge_method),
428 )
429 .replace(
430 "{{state}}",
431 &html_escape(params.state.as_deref().unwrap_or("")),
432 )
433 .replace(
434 "{{scope}}",
435 &html_escape(params.scope.as_deref().unwrap_or("")),
436 )
437 .replace("{{auth_mode}}", &html_escape(auth_mode))
438 .replace("{{authorize_url}}", "/_api/oauth/authorize")
439 .replace("{{error_message}}", "");
440
441 let mut response = (StatusCode::OK, Html(html)).into_response();
442 response
444 .headers_mut()
445 .insert("X-Frame-Options", HeaderValue::from_static("DENY"));
446 response.headers_mut().insert(
447 "Content-Security-Policy",
448 HeaderValue::from_static("frame-ancestors 'none'"),
449 );
450 let csrf_secure_flag = if is_https(&headers) { "; Secure" } else { "" };
452 let cookie = format!(
453 "forge_oauth_csrf={csrf_token}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=600{csrf_secure_flag}"
454 );
455 if let Ok(cookie_val) = HeaderValue::from_str(&cookie) {
456 response
457 .headers_mut()
458 .insert(header::SET_COOKIE, cookie_val);
459 }
460 response
461}
462
/// Form body accepted by `oauth_authorize_post`.
///
/// The user is identified by an existing `forge_session` cookie, by `token`, or
/// by `email` together with `password`.
#[derive(Deserialize)]
pub struct AuthorizeForm {
    /// Must equal the `forge_oauth_csrf` cookie and still be stored server-side.
    pub csrf_token: String,
    pub client_id: String,
    pub redirect_uri: String,
    pub code_challenge: String,
    pub code_challenge_method: String,
    pub state: Option<String>,
    /// Whitespace-separated scope names.
    pub scope: Option<String>,
    pub response_type: Option<String>,
    /// Bearer token checked through the auth middleware.
    pub token: Option<String>,
    /// Email for direct login.
    pub email: Option<String>,
    /// Password for direct login.
    pub password: Option<String>,
}
479
480pub async fn oauth_authorize_post(
481 headers: HeaderMap,
482 State(state): State<Arc<OAuthState>>,
483 axum::Form(form): axum::Form<AuthorizeForm>,
484) -> Response {
485 let csrf_from_cookie = extract_cookie(&headers, "forge_oauth_csrf");
487 let csrf_valid = if let Some(cookie_csrf) = csrf_from_cookie {
488 cookie_csrf == form.csrf_token && state.validate_csrf(&form.csrf_token).await
489 } else {
490 false
491 };
492 if !csrf_valid {
493 return (
494 StatusCode::FORBIDDEN,
495 Json(serde_json::json!({
496 "error": "csrf_validation_failed",
497 "error_description": "Invalid or expired CSRF token. Please try again."
498 })),
499 )
500 .into_response();
501 }
502
503 let ip = client_ip(&headers);
505 let rate_key = format!("oauth_login:{ip}");
506
507 let client = sqlx::query!(
509 "SELECT redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
510 &form.client_id,
511 )
512 .fetch_optional(&state.pool)
513 .await;
514
515 let redirect_uris = match client {
516 Ok(Some(c)) => c.redirect_uris,
517 _ => {
518 return (
519 StatusCode::BAD_REQUEST,
520 Json(serde_json::json!({
521 "error": "invalid_client"
522 })),
523 )
524 .into_response();
525 }
526 };
527
528 if !validate_redirect_uri(&form.redirect_uri, &redirect_uris) {
529 return (
530 StatusCode::BAD_REQUEST,
531 Json(serde_json::json!({
532 "error": "invalid_redirect_uri"
533 })),
534 )
535 .into_response();
536 }
537
538 let user_id: Uuid;
540
541 let session_subject = extract_cookie(&headers, "forge_session")
542 .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
543
544 if let Some(subject) = session_subject {
545 user_id = subject.parse::<Uuid>().unwrap_or_else(|_| {
548 use sha2::Digest;
550 let hash: [u8; 32] = sha2::Sha256::digest(subject.as_bytes()).into();
551 let mut bytes = [0u8; 16];
552 bytes.copy_from_slice(&hash[..16]);
553 Uuid::from_bytes(bytes)
554 });
555 } else if let Some(token) = &form.token {
556 match state.auth_middleware.validate_token_async(token).await {
558 Ok(claims) => {
559 user_id = claims
560 .user_id()
561 .ok_or(())
562 .map_err(|_| ())
563 .unwrap_or_default();
564 if user_id.is_nil() {
565 return authorize_error_redirect(
566 &form.redirect_uri,
567 form.state.as_deref(),
568 "access_denied",
569 "Invalid user identity in token",
570 );
571 }
572 }
573 Err(_) => {
574 return authorize_error_redirect(
575 &form.redirect_uri,
576 form.state.as_deref(),
577 "access_denied",
578 "Invalid or expired token. Please log in again.",
579 );
580 }
581 }
582 } else if let (Some(email), Some(password)) = (&form.email, &form.password) {
583 if !state.auth_is_hmac {
585 return authorize_error_redirect(
586 &form.redirect_uri,
587 form.state.as_deref(),
588 "access_denied",
589 "Direct login not supported with external auth provider",
590 );
591 }
592
593 if !state
594 .rate_limiter
595 .check(&rate_key, LOGIN_FAIL_RATE_LIMIT)
596 .await
597 {
598 return authorize_error_redirect(
599 &form.redirect_uri,
600 form.state.as_deref(),
601 "access_denied",
602 "Too many login attempts. Please try again later.",
603 );
604 }
605
606 let row = sqlx::query!(
608 "SELECT id, password_hash, role::TEXT FROM users WHERE email = $1",
609 email,
610 )
611 .fetch_optional(&state.pool)
612 .await;
613
614 const DUMMY_HASH: &str = "$2b$10$x5F0VyTQ6qjX5YKr.WPmXuGNQzGqGN1pYnHvMBRz5bFm3VUSqJGi";
618 let (found_id, hash) = match &row {
619 Ok(Some(r)) if r.password_hash.is_some() => {
620 (Some(r.id), r.password_hash.as_deref().unwrap_or(DUMMY_HASH))
621 }
622 _ => (None, DUMMY_HASH),
623 };
624 let password_valid = bcrypt::verify(password, hash).unwrap_or(false);
625 if password_valid {
626 if let Some(id) = found_id {
627 user_id = id;
628 } else {
629 return authorize_error_redirect(
630 &form.redirect_uri,
631 form.state.as_deref(),
632 "access_denied",
633 "Invalid email or password",
634 );
635 }
636 } else {
637 return authorize_error_redirect(
638 &form.redirect_uri,
639 form.state.as_deref(),
640 "access_denied",
641 "Invalid email or password",
642 );
643 }
644 } else {
645 return (
646 StatusCode::BAD_REQUEST,
647 Json(serde_json::json!({
648 "error": "invalid_request",
649 "error_description": "Must provide either a token or email/password"
650 })),
651 )
652 .into_response();
653 }
654
655 let code = oauth::generate_random_token();
657 let expires_at = Utc::now() + chrono::Duration::seconds(AUTH_CODE_TTL_SECS);
658 let scopes: Vec<String> = form
659 .scope
660 .as_deref()
661 .map(|s| s.split_whitespace().map(String::from).collect())
662 .unwrap_or_default();
663
664 let result = sqlx::query!(
665 "INSERT INTO forge_oauth_codes \
666 (code, client_id, user_id, redirect_uri, code_challenge, code_challenge_method, scopes, expires_at) \
667 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
668 &code,
669 &form.client_id,
670 user_id,
671 &form.redirect_uri,
672 &form.code_challenge,
673 &form.code_challenge_method,
674 &scopes,
675 expires_at,
676 )
677 .execute(&state.pool)
678 .await;
679
680 if let Err(e) = result {
681 tracing::error!("Failed to store authorization code: {e}");
682 return authorize_error_redirect(
683 &form.redirect_uri,
684 form.state.as_deref(),
685 "server_error",
686 "Failed to generate authorization code",
687 );
688 }
689
690 let mut redirect_url = format!("{}?code={}", form.redirect_uri, urlencoding(&code));
692 if let Some(st) = &form.state {
693 redirect_url.push_str(&format!("&state={}", urlencoding(st)));
694 }
695
696 let mut response = Redirect::to(&redirect_url).into_response();
697 response
698 .headers_mut()
699 .insert("Referrer-Policy", HeaderValue::from_static("no-referrer"));
700
701 let cookie_value = super::auth::sign_session_cookie(&user_id.to_string(), &state.jwt_secret);
705 let secure_flag = if is_https(&headers) { "; Secure" } else { "" };
706 let session_cookie = format!(
707 "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=86400{secure_flag}"
708 );
709 if let Ok(val) = HeaderValue::from_str(&session_cookie) {
710 response.headers_mut().append(header::SET_COOKIE, val);
711 }
712
713 response
714}
715
/// Body accepted by `oauth_token`, parsed from JSON or from form encoding.
///
/// Which optional fields are required depends on `grant_type`; the per-grant
/// handlers report missing ones as `invalid_request`.
#[derive(Deserialize)]
pub struct TokenRequest {
    /// `"authorization_code"` or `"refresh_token"`.
    pub grant_type: String,
    pub code: Option<String>,
    pub redirect_uri: Option<String>,
    /// PKCE verifier matching the challenge sent to the authorize endpoint.
    pub code_verifier: Option<String>,
    pub client_id: Option<String>,
    pub refresh_token: Option<String>,
}
727
/// Successful response body of `oauth_token`.
#[derive(Serialize)]
pub struct TokenResponse {
    pub access_token: String,
    /// Always `"Bearer"` for tokens issued by this module.
    pub token_type: String,
    /// Access-token lifetime in seconds.
    pub expires_in: i64,
    pub refresh_token: String,
}
735
736pub async fn oauth_token(
739 State(state): State<Arc<OAuthState>>,
740 headers: HeaderMap,
741 body: axum::body::Bytes,
742) -> Response {
743 let content_type = headers
744 .get(header::CONTENT_TYPE)
745 .and_then(|v| v.to_str().ok())
746 .unwrap_or("");
747
748 let req: TokenRequest = if content_type.starts_with("application/json") {
749 match serde_json::from_slice(&body) {
750 Ok(r) => r,
751 Err(e) => return token_error("invalid_request", &format!("Invalid JSON: {e}")),
752 }
753 } else {
754 match serde_urlencoded::from_bytes(&body) {
756 Ok(r) => r,
757 Err(e) => return token_error("invalid_request", &format!("Invalid form data: {e}")),
758 }
759 };
760
761 match req.grant_type.as_str() {
762 "authorization_code" => handle_code_exchange(&state, &req).await,
763 "refresh_token" => handle_refresh(&state, &req).await,
764 _ => (
765 StatusCode::BAD_REQUEST,
766 Json(serde_json::json!({
767 "error": "unsupported_grant_type"
768 })),
769 )
770 .into_response(),
771 }
772}
773
774async fn handle_code_exchange(state: &OAuthState, req: &TokenRequest) -> Response {
775 let code = match &req.code {
776 Some(c) => c,
777 None => return token_error("invalid_request", "code is required"),
778 };
779 let code_verifier = match &req.code_verifier {
780 Some(v) => v,
781 None => return token_error("invalid_request", "code_verifier is required"),
782 };
783 let redirect_uri = match &req.redirect_uri {
784 Some(r) => r,
785 None => return token_error("invalid_request", "redirect_uri is required"),
786 };
787 let client_id = match &req.client_id {
788 Some(c) => c,
789 None => return token_error("invalid_request", "client_id is required"),
790 };
791
792 let row = sqlx::query!(
794 "UPDATE forge_oauth_codes SET used_at = now() \
795 WHERE code = $1 AND used_at IS NULL \
796 RETURNING client_id, user_id, redirect_uri, code_challenge, code_challenge_method, expires_at",
797 code,
798 )
799 .fetch_optional(&state.pool)
800 .await;
801
802 let (
803 stored_client_id,
804 user_id,
805 stored_redirect,
806 stored_challenge,
807 challenge_method,
808 expires_at,
809 ) = match row {
810 Ok(Some(r)) => (
811 r.client_id,
812 r.user_id,
813 r.redirect_uri,
814 r.code_challenge,
815 r.code_challenge_method,
816 r.expires_at,
817 ),
818 Ok(None) => {
819 return token_error(
820 "invalid_grant",
821 "Invalid or already used authorization code",
822 );
823 }
824 Err(e) => {
825 tracing::error!("Failed to exchange authorization code: {e}");
826 return token_error("server_error", "Failed to exchange code");
827 }
828 };
829
830 if Utc::now() > expires_at {
832 return token_error("invalid_grant", "Authorization code has expired");
833 }
834
835 if *client_id != stored_client_id {
837 return token_error("invalid_grant", "client_id does not match");
838 }
839
840 if *redirect_uri != stored_redirect {
842 return token_error("invalid_grant", "redirect_uri does not match");
843 }
844
845 if challenge_method != CHALLENGE_METHOD_S256 {
846 return token_error("invalid_request", "Unsupported code_challenge_method");
847 }
848 if !forge_core::oauth::pkce::verify_s256(code_verifier, &stored_challenge) {
849 return token_error("invalid_grant", "PKCE verification failed");
850 }
851
852 let access_ttl = state.access_token_ttl_secs;
853 let refresh_ttl = state.refresh_token_ttl_days;
854
855 let pair = forge_core::auth::tokens::issue_token_pair_with_client(
856 &state.pool,
857 user_id,
858 &["user"],
859 access_ttl,
860 refresh_ttl,
861 Some(client_id),
862 mcp_token_issuer(state.token_issuer.clone()),
863 )
864 .await;
865
866 match pair {
867 Ok(pair) => (
868 StatusCode::OK,
869 Json(TokenResponse {
870 access_token: pair.access_token,
871 token_type: "Bearer".into(),
872 expires_in: access_ttl,
873 refresh_token: pair.refresh_token,
874 }),
875 )
876 .into_response(),
877 Err(e) => {
878 tracing::error!("Failed to issue token pair: {e}");
879 token_error("server_error", "Failed to issue tokens")
880 }
881 }
882}
883
884async fn handle_refresh(state: &OAuthState, req: &TokenRequest) -> Response {
885 let refresh_token = match &req.refresh_token {
886 Some(t) => t,
887 None => return token_error("invalid_request", "refresh_token is required"),
888 };
889 let client_id = req.client_id.as_deref();
890
891 let access_ttl = state.access_token_ttl_secs;
892 let refresh_ttl = state.refresh_token_ttl_days;
893
894 let pair = forge_core::auth::tokens::rotate_refresh_token_with_client(
895 &state.pool,
896 refresh_token,
897 &["user"],
898 access_ttl,
899 refresh_ttl,
900 client_id,
901 mcp_token_issuer(state.token_issuer.clone()),
902 )
903 .await;
904
905 match pair {
906 Ok(pair) => (
907 StatusCode::OK,
908 Json(TokenResponse {
909 access_token: pair.access_token,
910 token_type: "Bearer".into(),
911 expires_in: access_ttl,
912 refresh_token: pair.refresh_token,
913 }),
914 )
915 .into_response(),
916 Err(_) => token_error("invalid_grant", "Invalid or expired refresh token"),
917 }
918}
919
920fn mcp_token_issuer(
924 issuer: Arc<dyn forge_core::TokenIssuer>,
925) -> impl FnOnce(Uuid, &[&str], i64) -> forge_core::Result<String> {
926 move |uid, roles, ttl| {
927 let claims = Claims::builder()
928 .subject(uid)
929 .roles(roles.iter().map(|s| s.to_string()).collect())
930 .claim("aud".to_string(), serde_json::json!(MCP_AUDIENCE))
931 .duration_secs(ttl)
932 .build()
933 .map_err(forge_core::ForgeError::Internal)?;
934 issuer.sign(&claims)
935 }
936}
937
938fn is_https(headers: &HeaderMap) -> bool {
939 headers
940 .get("x-forwarded-proto")
941 .and_then(|v| v.to_str().ok())
942 .map(|s| s == "https")
943 .unwrap_or(false)
944}
945
946fn token_error(error: &str, description: &str) -> Response {
947 (
948 StatusCode::BAD_REQUEST,
949 Json(serde_json::json!({
950 "error": error,
951 "error_description": description
952 })),
953 )
954 .into_response()
955}
956
957fn authorize_error_redirect(
958 redirect_uri: &str,
959 state: Option<&str>,
960 error: &str,
961 description: &str,
962) -> Response {
963 let mut url = format!(
964 "{}?error={}&error_description={}",
965 redirect_uri,
966 urlencoding(error),
967 urlencoding(description),
968 );
969 if let Some(st) = state {
970 url.push_str(&format!("&state={}", urlencoding(st)));
971 }
972 Redirect::to(&url).into_response()
973}
974
975fn base_url_from_headers(headers: &HeaderMap) -> String {
976 let host = headers
977 .get("host")
978 .and_then(|v| v.to_str().ok())
979 .unwrap_or("localhost:9081");
980
981 let scheme = headers
982 .get("x-forwarded-proto")
983 .and_then(|v| v.to_str().ok())
984 .unwrap_or("http");
985
986 format!("{scheme}://{host}")
987}
988
989fn client_ip(headers: &HeaderMap) -> String {
990 headers
991 .get("x-forwarded-for")
992 .and_then(|v| v.to_str().ok())
993 .and_then(|s| s.split(',').next())
994 .map(|s| s.trim().to_string())
995 .or_else(|| {
996 headers
997 .get("x-real-ip")
998 .and_then(|v| v.to_str().ok())
999 .map(String::from)
1000 })
1001 .unwrap_or_else(|| "unknown".to_string())
1002}
1003
1004fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
1005 headers
1006 .get(header::COOKIE)
1007 .and_then(|v| v.to_str().ok())
1008 .and_then(|cookies| {
1009 cookies.split(';').map(|c| c.trim()).find_map(|c| {
1010 let (k, v) = c.split_once('=')?;
1011 if k == name { Some(v.to_string()) } else { None }
1012 })
1013 })
1014}
1015
/// Escapes `s` for safe inclusion in HTML text and quoted attribute values.
///
/// The five significant characters are replaced by entities: `&`, `<`, `>`,
/// `"` and `'`. Everything else is copied unchanged. The input is walked once,
/// so an ampersand produced by one replacement is never escaped a second time.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
1023
/// Percent-encodes `s` for use as a URL query value.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged; every
/// other byte of the UTF-8 encoding becomes `%XX` with uppercase hex digits.
fn urlencoding(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}