1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use axum::Json;
11use axum::extract::{Query, State};
12use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
13use axum::response::{Html, IntoResponse, Redirect, Response};
14use chrono::Utc;
15use forge_core::auth::Claims;
16use forge_core::oauth::{self, validate_redirect_uri};
17use serde::{Deserialize, Serialize};
18use tokio::sync::RwLock;
19use uuid::Uuid;
20
21use super::auth::AuthMiddleware;
22
23const AUTHORIZE_PAGE: &str = include_str!("oauth_authorize.html");
24const AUTH_CODE_TTL_SECS: i64 = 60;
25const MAX_REGISTERED_CLIENTS: i64 = 1000;
26const CHALLENGE_METHOD_S256: &str = "S256";
27const MCP_AUDIENCE: &str = "forge:mcp";
28
29const REGISTER_RATE_LIMIT: u32 = 10; const LOGIN_FAIL_RATE_LIMIT: u32 = 5; const RATE_WINDOW_SECS: u64 = 60;
33const RATE_CLEANUP_THRESHOLD: usize = 100;
34
35#[derive(Clone, Default)]
37struct OAuthRateLimiter {
38 buckets: Arc<RwLock<HashMap<String, (u32, Instant)>>>,
39}
40
41impl OAuthRateLimiter {
42 async fn check(&self, key: &str, limit: u32) -> bool {
43 let mut buckets = self.buckets.write().await;
44 let now = Instant::now();
45 let window = Duration::from_secs(RATE_WINDOW_SECS);
46
47 if buckets.len() > RATE_CLEANUP_THRESHOLD {
49 buckets.retain(|_, (_, ts)| now.duration_since(*ts) <= window);
50 }
51
52 let entry = buckets.entry(key.to_string()).or_insert((0, now));
53 if now.duration_since(entry.1) > window {
54 *entry = (1, now);
55 return true;
56 }
57 if entry.0 >= limit {
58 return false;
59 }
60 entry.0 += 1;
61 true
62 }
63}
64
65#[derive(Clone)]
67pub struct OAuthState {
68 pool: sqlx::PgPool,
69 auth_middleware: Arc<AuthMiddleware>,
70 token_issuer: Arc<dyn forge_core::TokenIssuer>,
71 access_token_ttl_secs: i64,
72 refresh_token_ttl_days: i64,
73 auth_is_hmac: bool,
74 project_name: String,
75 jwt_secret: String,
76 rate_limiter: OAuthRateLimiter,
77 csrf_tokens: Arc<RwLock<HashMap<String, Instant>>>,
79}
80
81impl OAuthState {
82 #[allow(clippy::too_many_arguments)]
83 pub fn new(
84 pool: sqlx::PgPool,
85 auth_middleware: Arc<AuthMiddleware>,
86 token_issuer: Arc<dyn forge_core::TokenIssuer>,
87 access_token_ttl_secs: i64,
88 refresh_token_ttl_days: i64,
89 auth_is_hmac: bool,
90 project_name: String,
91 jwt_secret: String,
92 ) -> Self {
93 Self {
94 pool,
95 auth_middleware,
96 token_issuer,
97 access_token_ttl_secs,
98 refresh_token_ttl_days,
99 auth_is_hmac,
100 project_name,
101 jwt_secret,
102 rate_limiter: OAuthRateLimiter::default(),
103 csrf_tokens: Arc::new(RwLock::new(HashMap::new())),
104 }
105 }
106
107 async fn store_csrf(&self, token: &str) {
108 let mut tokens = self.csrf_tokens.write().await;
109 let now = Instant::now();
110 let expiry = now + Duration::from_secs(600); tokens.insert(token.to_string(), expiry);
112 if tokens.len() > RATE_CLEANUP_THRESHOLD {
114 tokens.retain(|_, exp| *exp > now);
115 }
116 }
117
118 async fn validate_csrf(&self, token: &str) -> bool {
119 let mut tokens = self.csrf_tokens.write().await;
120 if let Some(expiry) = tokens.remove(token) {
121 expiry > Instant::now()
122 } else {
123 false
124 }
125 }
126}
127
128#[derive(Serialize)]
131pub struct AuthorizationServerMetadata {
132 issuer: String,
133 authorization_endpoint: String,
134 token_endpoint: String,
135 registration_endpoint: String,
136 response_types_supported: Vec<String>,
137 grant_types_supported: Vec<String>,
138 code_challenge_methods_supported: Vec<String>,
139 token_endpoint_auth_methods_supported: Vec<String>,
140}
141
142pub async fn well_known_oauth_metadata(
143 headers: HeaderMap,
144 State(_state): State<Arc<OAuthState>>,
145) -> Json<AuthorizationServerMetadata> {
146 let base = base_url_from_headers(&headers);
147 Json(AuthorizationServerMetadata {
148 issuer: base.clone(),
149 authorization_endpoint: format!("{base}/_api/oauth/authorize"),
150 token_endpoint: format!("{base}/_api/oauth/token"),
151 registration_endpoint: format!("{base}/_api/oauth/register"),
152 response_types_supported: vec!["code".into()],
153 grant_types_supported: vec!["authorization_code".into(), "refresh_token".into()],
154 code_challenge_methods_supported: vec![CHALLENGE_METHOD_S256.into()],
155 token_endpoint_auth_methods_supported: vec!["none".into()],
156 })
157}
158
159#[derive(Serialize)]
160pub struct ProtectedResourceMetadata {
161 resource: String,
162 authorization_servers: Vec<String>,
163}
164
165pub async fn well_known_resource_metadata(
166 headers: HeaderMap,
167 State(_state): State<Arc<OAuthState>>,
168) -> Json<ProtectedResourceMetadata> {
169 let base = base_url_from_headers(&headers);
170 Json(ProtectedResourceMetadata {
171 resource: base.clone(),
172 authorization_servers: vec![base],
173 })
174}
175
/// JSON body of a dynamic client registration request (see `oauth_register`).
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Optional display name, shown on the consent page instead of the client_id.
    pub client_name: Option<String>,
    /// Redirect URIs the client may use; `oauth_register` rejects an empty list.
    pub redirect_uris: Vec<String>,
    /// Requested grant types. Echoed back in the response (defaulting to
    /// `authorization_code` when empty) but not persisted.
    #[serde(default)]
    pub grant_types: Vec<String>,
    /// Requested token endpoint auth method; stored as `"none"` when omitted.
    #[serde(default)]
    pub token_endpoint_auth_method: Option<String>,
}

/// Body of a successful registration response (sent with `201 Created`).
/// Field order is the serialized JSON key order.
#[derive(Serialize)]
pub struct RegisterResponse {
    /// Freshly generated UUID (v4) identifying the client.
    pub client_id: String,
    pub client_name: Option<String>,
    pub redirect_uris: Vec<String>,
    pub grant_types: Vec<String>,
    pub token_endpoint_auth_method: String,
}
196
197pub async fn oauth_register(
198 headers: HeaderMap,
199 State(state): State<Arc<OAuthState>>,
200 Json(req): Json<RegisterRequest>,
201) -> Response {
202 let ip = client_ip(&headers);
203 let rate_key = format!("oauth_register:{ip}");
204 if !state
205 .rate_limiter
206 .check(&rate_key, REGISTER_RATE_LIMIT)
207 .await
208 {
209 return (
210 StatusCode::TOO_MANY_REQUESTS,
211 Json(serde_json::json!({
212 "error": "too_many_requests",
213 "error_description": "Rate limit exceeded for client registration"
214 })),
215 )
216 .into_response();
217 }
218
219 let count: i64 = sqlx::query_scalar!("SELECT COUNT(*) FROM forge_oauth_clients")
221 .fetch_one(&state.pool)
222 .await
223 .unwrap_or(Some(0))
224 .unwrap_or(0);
225 if count >= MAX_REGISTERED_CLIENTS {
226 return (
227 StatusCode::BAD_REQUEST,
228 Json(serde_json::json!({
229 "error": "too_many_clients",
230 "error_description": "Maximum number of registered clients reached"
231 })),
232 )
233 .into_response();
234 }
235
236 if req.redirect_uris.is_empty() {
237 return (
238 StatusCode::BAD_REQUEST,
239 Json(serde_json::json!({
240 "error": "invalid_client_metadata",
241 "error_description": "redirect_uris is required"
242 })),
243 )
244 .into_response();
245 }
246
247 for uri in &req.redirect_uris {
249 if uri.contains('#') {
251 return (
252 StatusCode::BAD_REQUEST,
253 Json(serde_json::json!({
254 "error": "invalid_redirect_uri",
255 "error_description": "redirect_uri must not contain a fragment"
256 })),
257 )
258 .into_response();
259 }
260 let is_localhost = uri.starts_with("http://localhost")
262 || uri.starts_with("http://127.0.0.1")
263 || uri.starts_with("http://[::1]");
264 let is_https = uri.starts_with("https://");
265 if !is_localhost && !is_https {
266 return (
267 StatusCode::BAD_REQUEST,
268 Json(serde_json::json!({
269 "error": "invalid_redirect_uri",
270 "error_description": "redirect_uri must use HTTPS for non-localhost URIs"
271 })),
272 )
273 .into_response();
274 }
275 }
276
277 let client_id = Uuid::new_v4().to_string();
278 let auth_method = req.token_endpoint_auth_method.as_deref().unwrap_or("none");
279
280 let result = sqlx::query!(
281 "INSERT INTO forge_oauth_clients (client_id, client_name, redirect_uris, token_endpoint_auth_method) \
282 VALUES ($1, $2, $3, $4)",
283 &client_id,
284 req.client_name as _,
285 &req.redirect_uris,
286 auth_method,
287 )
288 .execute(&state.pool)
289 .await;
290
291 if let Err(e) = result {
292 tracing::error!("Failed to register OAuth client: {e}");
293 return (
294 StatusCode::INTERNAL_SERVER_ERROR,
295 Json(serde_json::json!({
296 "error": "server_error",
297 "error_description": "Failed to register client"
298 })),
299 )
300 .into_response();
301 }
302
303 let grant_types = if req.grant_types.is_empty() {
304 vec!["authorization_code".into()]
305 } else {
306 req.grant_types
307 };
308
309 (
310 StatusCode::CREATED,
311 Json(RegisterResponse {
312 client_id,
313 client_name: req.client_name,
314 redirect_uris: req.redirect_uris,
315 grant_types,
316 token_endpoint_auth_method: auth_method.to_string(),
317 }),
318 )
319 .into_response()
320}
321
322#[derive(Deserialize)]
325pub struct AuthorizeQuery {
326 pub client_id: String,
327 pub redirect_uri: String,
328 pub code_challenge: String,
329 #[serde(default = "default_s256")]
330 pub code_challenge_method: String,
331 pub state: Option<String>,
332 pub scope: Option<String>,
333 pub response_type: Option<String>,
334}
335
336fn default_s256() -> String {
337 CHALLENGE_METHOD_S256.into()
338}
339
340pub async fn oauth_authorize_get(
341 headers: HeaderMap,
342 Query(params): Query<AuthorizeQuery>,
343 State(state): State<Arc<OAuthState>>,
344) -> Response {
345 let client = sqlx::query!(
347 "SELECT client_id, client_name, redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
348 ¶ms.client_id,
349 )
350 .fetch_optional(&state.pool)
351 .await;
352
353 let (_, client_name, redirect_uris) = match client {
354 Ok(Some(c)) => (c.client_id, c.client_name, c.redirect_uris),
355 Ok(None) => {
356 return (
357 StatusCode::BAD_REQUEST,
358 Json(serde_json::json!({
359 "error": "invalid_client",
360 "error_description": "Unknown client_id"
361 })),
362 )
363 .into_response();
364 }
365 Err(e) => {
366 tracing::error!("OAuth client lookup failed: {e}");
367 return (
368 StatusCode::INTERNAL_SERVER_ERROR,
369 Json(serde_json::json!({
370 "error": "server_error"
371 })),
372 )
373 .into_response();
374 }
375 };
376
377 if !validate_redirect_uri(¶ms.redirect_uri, &redirect_uris) {
379 return (
380 StatusCode::BAD_REQUEST,
381 Json(serde_json::json!({
382 "error": "invalid_redirect_uri",
383 "error_description": "redirect_uri does not match any registered URI"
384 })),
385 )
386 .into_response();
387 }
388
389 if params.code_challenge_method != CHALLENGE_METHOD_S256 {
390 return (
391 StatusCode::BAD_REQUEST,
392 Json(serde_json::json!({
393 "error": "invalid_request",
394 "error_description": "Only S256 code_challenge_method is supported"
395 })),
396 )
397 .into_response();
398 }
399
400 let session_subject = extract_cookie(&headers, "forge_session")
402 .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
403 let has_session = session_subject.is_some();
404
405 let csrf_token = oauth::generate_random_token();
407 state.store_csrf(&csrf_token).await;
408
409 let auth_mode = if has_session {
410 "session" } else if state.auth_is_hmac {
412 "hmac" } else {
414 "external" };
416 let display_name = client_name.as_deref().unwrap_or(¶ms.client_id);
417
418 let html = AUTHORIZE_PAGE
419 .replace("{{app_name}}", &html_escape(&state.project_name))
420 .replace("{{client_name}}", &html_escape(display_name))
421 .replace("{{csrf_token}}", &csrf_token)
422 .replace("{{client_id}}", &html_escape(¶ms.client_id))
423 .replace("{{redirect_uri}}", &html_escape(¶ms.redirect_uri))
424 .replace("{{code_challenge}}", &html_escape(¶ms.code_challenge))
425 .replace(
426 "{{code_challenge_method}}",
427 &html_escape(¶ms.code_challenge_method),
428 )
429 .replace(
430 "{{state}}",
431 &html_escape(params.state.as_deref().unwrap_or("")),
432 )
433 .replace(
434 "{{scope}}",
435 &html_escape(params.scope.as_deref().unwrap_or("")),
436 )
437 .replace("{{auth_mode}}", auth_mode)
438 .replace("{{authorize_url}}", "/_api/oauth/authorize")
439 .replace("{{error_message}}", "");
440
441 let mut response = (StatusCode::OK, Html(html)).into_response();
442 response
444 .headers_mut()
445 .insert("X-Frame-Options", HeaderValue::from_static("DENY"));
446 response.headers_mut().insert(
447 "Content-Security-Policy",
448 HeaderValue::from_static("frame-ancestors 'none'"),
449 );
450 let csrf_secure_flag = if is_https(&headers) { "; Secure" } else { "" };
452 let cookie = format!(
453 "forge_oauth_csrf={csrf_token}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=600{csrf_secure_flag}"
454 );
455 if let Ok(cookie_val) = HeaderValue::from_str(&cookie) {
456 response
457 .headers_mut()
458 .insert(header::SET_COOKIE, cookie_val);
459 }
460 response
461}
462
/// Form body submitted to `oauth_authorize_post`.
///
/// `client_id` and `redirect_uri` are re-validated by the handler. The user is
/// authenticated from, in order of precedence: a valid `forge_session` cookie,
/// `token`, or `email` + `password`.
#[derive(Deserialize)]
pub struct AuthorizeForm {
    /// Must equal the `forge_oauth_csrf` cookie and a still-outstanding server-side token.
    pub csrf_token: String,
    pub client_id: String,
    pub redirect_uri: String,
    /// Stored with the issued code and checked at the token endpoint.
    pub code_challenge: String,
    /// Stored with the issued code; the token endpoint only accepts `S256`.
    pub code_challenge_method: String,
    /// Opaque client state, echoed back on the redirect.
    pub state: Option<String>,
    /// Space-separated scopes, stored with the code.
    pub scope: Option<String>,
    /// Not read by the POST handler.
    pub response_type: Option<String>,
    /// Bearer token validated via `AuthMiddleware::validate_token_async`.
    pub token: Option<String>,
    /// Email for direct login (only honored when `auth_is_hmac` is set).
    pub email: Option<String>,
    /// Password for direct login, verified with bcrypt.
    pub password: Option<String>,
}
479
480pub async fn oauth_authorize_post(
481 headers: HeaderMap,
482 State(state): State<Arc<OAuthState>>,
483 axum::Form(form): axum::Form<AuthorizeForm>,
484) -> Response {
485 let csrf_from_cookie = extract_cookie(&headers, "forge_oauth_csrf");
487 let csrf_valid = if let Some(cookie_csrf) = csrf_from_cookie {
488 cookie_csrf == form.csrf_token && state.validate_csrf(&form.csrf_token).await
489 } else {
490 false
491 };
492 if !csrf_valid {
493 return (
494 StatusCode::FORBIDDEN,
495 Json(serde_json::json!({
496 "error": "csrf_validation_failed",
497 "error_description": "Invalid or expired CSRF token. Please try again."
498 })),
499 )
500 .into_response();
501 }
502
503 let ip = client_ip(&headers);
505 let rate_key = format!("oauth_login:{ip}");
506
507 let client = sqlx::query!(
509 "SELECT redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
510 &form.client_id,
511 )
512 .fetch_optional(&state.pool)
513 .await;
514
515 let redirect_uris = match client {
516 Ok(Some(c)) => c.redirect_uris,
517 _ => {
518 return (
519 StatusCode::BAD_REQUEST,
520 Json(serde_json::json!({
521 "error": "invalid_client"
522 })),
523 )
524 .into_response();
525 }
526 };
527
528 if !validate_redirect_uri(&form.redirect_uri, &redirect_uris) {
529 return (
530 StatusCode::BAD_REQUEST,
531 Json(serde_json::json!({
532 "error": "invalid_redirect_uri"
533 })),
534 )
535 .into_response();
536 }
537
538 let user_id: Uuid;
540
541 let session_subject = extract_cookie(&headers, "forge_session")
542 .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
543
544 if let Some(subject) = session_subject {
545 user_id = subject.parse::<Uuid>().unwrap_or_else(|_| {
548 use sha2::Digest;
550 let hash: [u8; 32] = sha2::Sha256::digest(subject.as_bytes()).into();
551 let mut bytes = [0u8; 16];
552 bytes.copy_from_slice(&hash[..16]);
553 Uuid::from_bytes(bytes)
554 });
555 } else if let Some(token) = &form.token {
556 match state.auth_middleware.validate_token_async(token).await {
558 Ok(claims) => {
559 user_id = claims
560 .user_id()
561 .ok_or(())
562 .map_err(|_| ())
563 .unwrap_or_default();
564 if user_id.is_nil() {
565 return authorize_error_redirect(
566 &form.redirect_uri,
567 form.state.as_deref(),
568 "access_denied",
569 "Invalid user identity in token",
570 );
571 }
572 }
573 Err(_) => {
574 return authorize_error_redirect(
575 &form.redirect_uri,
576 form.state.as_deref(),
577 "access_denied",
578 "Invalid or expired token. Please log in again.",
579 );
580 }
581 }
582 } else if let (Some(email), Some(password)) = (&form.email, &form.password) {
583 if !state.auth_is_hmac {
585 return authorize_error_redirect(
586 &form.redirect_uri,
587 form.state.as_deref(),
588 "access_denied",
589 "Direct login not supported with external auth provider",
590 );
591 }
592
593 if !state
594 .rate_limiter
595 .check(&rate_key, LOGIN_FAIL_RATE_LIMIT)
596 .await
597 {
598 return authorize_error_redirect(
599 &form.redirect_uri,
600 form.state.as_deref(),
601 "access_denied",
602 "Too many login attempts. Please try again later.",
603 );
604 }
605
606 let row = sqlx::query!(
608 "SELECT id, password_hash, role::TEXT FROM users WHERE email = $1",
609 email,
610 )
611 .fetch_optional(&state.pool)
612 .await;
613
614 match row {
615 Ok(Some(r)) if r.password_hash.is_some() => {
616 match bcrypt::verify(
617 password,
618 r.password_hash.as_ref().expect("guarded by is_some check"),
619 ) {
620 Ok(true) => {
621 user_id = r.id;
622 }
623 _ => {
624 return authorize_error_redirect(
625 &form.redirect_uri,
626 form.state.as_deref(),
627 "access_denied",
628 "Invalid email or password",
629 );
630 }
631 }
632 }
633 _ => {
634 return authorize_error_redirect(
635 &form.redirect_uri,
636 form.state.as_deref(),
637 "access_denied",
638 "Invalid email or password",
639 );
640 }
641 }
642 } else {
643 return (
644 StatusCode::BAD_REQUEST,
645 Json(serde_json::json!({
646 "error": "invalid_request",
647 "error_description": "Must provide either a token or email/password"
648 })),
649 )
650 .into_response();
651 }
652
653 let code = oauth::generate_random_token();
655 let expires_at = Utc::now() + chrono::Duration::seconds(AUTH_CODE_TTL_SECS);
656 let scopes: Vec<String> = form
657 .scope
658 .as_deref()
659 .map(|s| s.split_whitespace().map(String::from).collect())
660 .unwrap_or_default();
661
662 let result = sqlx::query!(
663 "INSERT INTO forge_oauth_codes \
664 (code, client_id, user_id, redirect_uri, code_challenge, code_challenge_method, scopes, expires_at) \
665 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
666 &code,
667 &form.client_id,
668 user_id,
669 &form.redirect_uri,
670 &form.code_challenge,
671 &form.code_challenge_method,
672 &scopes,
673 expires_at,
674 )
675 .execute(&state.pool)
676 .await;
677
678 if let Err(e) = result {
679 tracing::error!("Failed to store authorization code: {e}");
680 return authorize_error_redirect(
681 &form.redirect_uri,
682 form.state.as_deref(),
683 "server_error",
684 "Failed to generate authorization code",
685 );
686 }
687
688 let mut redirect_url = format!("{}?code={}", form.redirect_uri, urlencoding(&code));
690 if let Some(st) = &form.state {
691 redirect_url.push_str(&format!("&state={}", urlencoding(st)));
692 }
693
694 let mut response = Redirect::to(&redirect_url).into_response();
695 response
696 .headers_mut()
697 .insert("Referrer-Policy", HeaderValue::from_static("no-referrer"));
698
699 let cookie_value = super::auth::sign_session_cookie(&user_id.to_string(), &state.jwt_secret);
703 let secure_flag = if is_https(&headers) { "; Secure" } else { "" };
704 let session_cookie = format!(
705 "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=86400{secure_flag}"
706 );
707 if let Ok(val) = HeaderValue::from_str(&session_cookie) {
708 response.headers_mut().append(header::SET_COOKIE, val);
709 }
710
711 response
712}
713
/// Token endpoint request, accepted as JSON or form-urlencoded (see `oauth_token`).
///
/// Required fields depend on `grant_type`:
/// - `authorization_code` needs `code`, `code_verifier`, `redirect_uri`, and `client_id`.
/// - `refresh_token` needs `refresh_token`; `client_id` is optional there.
#[derive(Deserialize)]
pub struct TokenRequest {
    pub grant_type: String,
    pub code: Option<String>,
    pub redirect_uri: Option<String>,
    /// PKCE verifier, checked against the stored S256 challenge.
    pub code_verifier: Option<String>,
    pub client_id: Option<String>,
    pub refresh_token: Option<String>,
}

/// Successful token endpoint response body.
#[derive(Serialize)]
pub struct TokenResponse {
    pub access_token: String,
    /// Always `"Bearer"` as set by both grant handlers.
    pub token_type: String,
    /// Access-token lifetime in seconds (`access_token_ttl_secs`).
    pub expires_in: i64,
    pub refresh_token: String,
}
733
734pub async fn oauth_token(
737 State(state): State<Arc<OAuthState>>,
738 headers: HeaderMap,
739 body: axum::body::Bytes,
740) -> Response {
741 let content_type = headers
742 .get(header::CONTENT_TYPE)
743 .and_then(|v| v.to_str().ok())
744 .unwrap_or("");
745
746 let req: TokenRequest = if content_type.starts_with("application/json") {
747 match serde_json::from_slice(&body) {
748 Ok(r) => r,
749 Err(e) => return token_error("invalid_request", &format!("Invalid JSON: {e}")),
750 }
751 } else {
752 match serde_urlencoded::from_bytes(&body) {
754 Ok(r) => r,
755 Err(e) => return token_error("invalid_request", &format!("Invalid form data: {e}")),
756 }
757 };
758
759 match req.grant_type.as_str() {
760 "authorization_code" => handle_code_exchange(&state, &req).await,
761 "refresh_token" => handle_refresh(&state, &req).await,
762 _ => (
763 StatusCode::BAD_REQUEST,
764 Json(serde_json::json!({
765 "error": "unsupported_grant_type"
766 })),
767 )
768 .into_response(),
769 }
770}
771
772async fn handle_code_exchange(state: &OAuthState, req: &TokenRequest) -> Response {
773 let code = match &req.code {
774 Some(c) => c,
775 None => return token_error("invalid_request", "code is required"),
776 };
777 let code_verifier = match &req.code_verifier {
778 Some(v) => v,
779 None => return token_error("invalid_request", "code_verifier is required"),
780 };
781 let redirect_uri = match &req.redirect_uri {
782 Some(r) => r,
783 None => return token_error("invalid_request", "redirect_uri is required"),
784 };
785 let client_id = match &req.client_id {
786 Some(c) => c,
787 None => return token_error("invalid_request", "client_id is required"),
788 };
789
790 let row = sqlx::query!(
792 "UPDATE forge_oauth_codes SET used_at = now() \
793 WHERE code = $1 AND used_at IS NULL \
794 RETURNING client_id, user_id, redirect_uri, code_challenge, code_challenge_method, expires_at",
795 code,
796 )
797 .fetch_optional(&state.pool)
798 .await;
799
800 let (
801 stored_client_id,
802 user_id,
803 stored_redirect,
804 stored_challenge,
805 challenge_method,
806 expires_at,
807 ) = match row {
808 Ok(Some(r)) => (
809 r.client_id,
810 r.user_id,
811 r.redirect_uri,
812 r.code_challenge,
813 r.code_challenge_method,
814 r.expires_at,
815 ),
816 Ok(None) => {
817 return token_error(
818 "invalid_grant",
819 "Invalid or already used authorization code",
820 );
821 }
822 Err(e) => {
823 tracing::error!("Failed to exchange authorization code: {e}");
824 return token_error("server_error", "Failed to exchange code");
825 }
826 };
827
828 if Utc::now() > expires_at {
830 return token_error("invalid_grant", "Authorization code has expired");
831 }
832
833 if *client_id != stored_client_id {
835 return token_error("invalid_grant", "client_id does not match");
836 }
837
838 if *redirect_uri != stored_redirect {
840 return token_error("invalid_grant", "redirect_uri does not match");
841 }
842
843 if challenge_method != CHALLENGE_METHOD_S256 {
844 return token_error("invalid_request", "Unsupported code_challenge_method");
845 }
846 if !forge_core::oauth::pkce::verify_s256(code_verifier, &stored_challenge) {
847 return token_error("invalid_grant", "PKCE verification failed");
848 }
849
850 let access_ttl = state.access_token_ttl_secs;
851 let refresh_ttl = state.refresh_token_ttl_days;
852
853 let pair = forge_core::auth::tokens::issue_token_pair_with_client(
854 &state.pool,
855 user_id,
856 &["user"],
857 access_ttl,
858 refresh_ttl,
859 Some(client_id),
860 mcp_token_issuer(state.token_issuer.clone()),
861 )
862 .await;
863
864 match pair {
865 Ok(pair) => (
866 StatusCode::OK,
867 Json(TokenResponse {
868 access_token: pair.access_token,
869 token_type: "Bearer".into(),
870 expires_in: access_ttl,
871 refresh_token: pair.refresh_token,
872 }),
873 )
874 .into_response(),
875 Err(e) => {
876 tracing::error!("Failed to issue token pair: {e}");
877 token_error("server_error", "Failed to issue tokens")
878 }
879 }
880}
881
882async fn handle_refresh(state: &OAuthState, req: &TokenRequest) -> Response {
883 let refresh_token = match &req.refresh_token {
884 Some(t) => t,
885 None => return token_error("invalid_request", "refresh_token is required"),
886 };
887 let client_id = req.client_id.as_deref();
888
889 let access_ttl = state.access_token_ttl_secs;
890 let refresh_ttl = state.refresh_token_ttl_days;
891
892 let pair = forge_core::auth::tokens::rotate_refresh_token_with_client(
893 &state.pool,
894 refresh_token,
895 &["user"],
896 access_ttl,
897 refresh_ttl,
898 client_id,
899 mcp_token_issuer(state.token_issuer.clone()),
900 )
901 .await;
902
903 match pair {
904 Ok(pair) => (
905 StatusCode::OK,
906 Json(TokenResponse {
907 access_token: pair.access_token,
908 token_type: "Bearer".into(),
909 expires_in: access_ttl,
910 refresh_token: pair.refresh_token,
911 }),
912 )
913 .into_response(),
914 Err(_) => token_error("invalid_grant", "Invalid or expired refresh token"),
915 }
916}
917
918fn mcp_token_issuer(
922 issuer: Arc<dyn forge_core::TokenIssuer>,
923) -> impl FnOnce(Uuid, &[&str], i64) -> forge_core::Result<String> {
924 move |uid, roles, ttl| {
925 let claims = Claims::builder()
926 .subject(uid)
927 .roles(roles.iter().map(|s| s.to_string()).collect())
928 .claim("aud".to_string(), serde_json::json!(MCP_AUDIENCE))
929 .duration_secs(ttl)
930 .build()
931 .map_err(forge_core::ForgeError::Internal)?;
932 issuer.sign(&claims)
933 }
934}
935
936fn is_https(headers: &HeaderMap) -> bool {
937 headers
938 .get("x-forwarded-proto")
939 .and_then(|v| v.to_str().ok())
940 .map(|s| s == "https")
941 .unwrap_or(false)
942}
943
944fn token_error(error: &str, description: &str) -> Response {
945 (
946 StatusCode::BAD_REQUEST,
947 Json(serde_json::json!({
948 "error": error,
949 "error_description": description
950 })),
951 )
952 .into_response()
953}
954
955fn authorize_error_redirect(
956 redirect_uri: &str,
957 state: Option<&str>,
958 error: &str,
959 description: &str,
960) -> Response {
961 let mut url = format!(
962 "{}?error={}&error_description={}",
963 redirect_uri,
964 urlencoding(error),
965 urlencoding(description),
966 );
967 if let Some(st) = state {
968 url.push_str(&format!("&state={}", urlencoding(st)));
969 }
970 Redirect::to(&url).into_response()
971}
972
973fn base_url_from_headers(headers: &HeaderMap) -> String {
974 let host = headers
975 .get("host")
976 .and_then(|v| v.to_str().ok())
977 .unwrap_or("localhost:9081");
978
979 let scheme = headers
980 .get("x-forwarded-proto")
981 .and_then(|v| v.to_str().ok())
982 .unwrap_or("http");
983
984 format!("{scheme}://{host}")
985}
986
987fn client_ip(headers: &HeaderMap) -> String {
988 headers
989 .get("x-forwarded-for")
990 .and_then(|v| v.to_str().ok())
991 .and_then(|s| s.split(',').next())
992 .map(|s| s.trim().to_string())
993 .or_else(|| {
994 headers
995 .get("x-real-ip")
996 .and_then(|v| v.to_str().ok())
997 .map(String::from)
998 })
999 .unwrap_or_else(|| "unknown".to_string())
1000}
1001
1002fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
1003 headers
1004 .get(header::COOKIE)
1005 .and_then(|v| v.to_str().ok())
1006 .and_then(|cookies| {
1007 cookies.split(';').map(|c| c.trim()).find_map(|c| {
1008 let (k, v) = c.split_once('=')?;
1009 if k == name { Some(v.to_string()) } else { None }
1010 })
1011 })
1012}
1013
/// Escapes `s` for interpolation into HTML text and into double- or
/// single-quoted attribute values.
///
/// `&`, `<`, `>`, `"` and `'` become character references. A single
/// character-by-character pass is used, so the `&` of an emitted reference is
/// never escaped a second time, while `&` already present in the input is.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            other => out.push(other),
        }
    }
    out
}
1021
/// Percent-encodes `s` byte by byte (RFC 3986 style).
///
/// Only the unreserved set (ASCII letters, digits, `-`, `_`, `.`, `~`) is kept
/// literally. Every other byte, including each byte of multi-byte UTF-8
/// characters, becomes `%XX` with uppercase hex.
fn urlencoding(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
        out
    })
}