1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use axum::Json;
11use axum::extract::{Query, State};
12use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
13use axum::response::{Html, IntoResponse, Redirect, Response};
14use chrono::Utc;
15use forge_core::auth::Claims;
16use forge_core::oauth::{self, validate_redirect_uri};
17use serde::{Deserialize, Serialize};
18use tokio::sync::RwLock;
19use uuid::Uuid;
20
21use super::auth::AuthMiddleware;
22
/// HTML template for the consent/login page, embedded at compile time.
/// Its `{{placeholder}}` markers are filled in by `oauth_authorize_get`.
const AUTHORIZE_PAGE: &str = include_str!("oauth_authorize.html");
/// Lifetime of an authorization code, in seconds.
const AUTH_CODE_TTL_SECS: i64 = 60;
/// Registration is refused once this many rows exist in `forge_oauth_clients`.
const MAX_REGISTERED_CLIENTS: i64 = 1000;
/// The only PKCE `code_challenge_method` accepted anywhere in this module.
const CHALLENGE_METHOD_S256: &str = "S256";
/// `aud` claim stamped on access tokens minted for MCP clients.
const MCP_AUDIENCE: &str = "forge:mcp";

/// Client registrations accepted per client IP per rate window.
const REGISTER_RATE_LIMIT: u32 = 10;
/// Email/password login attempts accepted per client IP per rate window.
/// Every attempt counts, not only failed ones.
const LOGIN_FAIL_RATE_LIMIT: u32 = 5;
/// Length of one rate-limit window, in seconds.
const RATE_WINDOW_SECS: u64 = 60;
/// Map size above which the rate limiter and the CSRF store purge stale entries.
const RATE_CLEANUP_THRESHOLD: usize = 100;
34
35#[derive(Clone, Default)]
37struct OAuthRateLimiter {
38 buckets: Arc<RwLock<HashMap<String, (u32, Instant)>>>,
39}
40
41impl OAuthRateLimiter {
42 async fn check(&self, key: &str, limit: u32) -> bool {
43 let mut buckets = self.buckets.write().await;
44 let now = Instant::now();
45 let window = Duration::from_secs(RATE_WINDOW_SECS);
46
47 if buckets.len() > RATE_CLEANUP_THRESHOLD {
49 buckets.retain(|_, (_, ts)| now.duration_since(*ts) <= window);
50 }
51
52 let entry = buckets.entry(key.to_string()).or_insert((0, now));
53 if now.duration_since(entry.1) > window {
54 *entry = (1, now);
55 return true;
56 }
57 if entry.0 >= limit {
58 return false;
59 }
60 entry.0 += 1;
61 true
62 }
63}
64
/// Shared state for the OAuth authorization-server handlers in this module.
#[derive(Clone)]
pub struct OAuthState {
    /// Postgres pool used for `forge_oauth_clients`, `forge_oauth_codes` and `users`.
    pool: sqlx::PgPool,
    /// Validates bearer tokens submitted through the authorize form.
    auth_middleware: Arc<AuthMiddleware>,
    /// Signs the access tokens minted for OAuth (MCP) clients.
    token_issuer: Arc<dyn forge_core::TokenIssuer>,
    /// Access token lifetime in seconds; also reported as `expires_in`.
    access_token_ttl_secs: i64,
    /// Refresh token lifetime handed to the token helpers (days, per the name).
    refresh_token_ttl_days: i64,
    /// True when the app uses built-in HMAC auth. Selects the `hmac` page mode
    /// and is required for email/password login on the authorize form.
    auth_is_hmac: bool,
    /// Application name shown on the authorize page.
    project_name: String,
    /// Secret used to sign and verify the `forge_session` cookie.
    jwt_secret: String,
    /// Per-IP limiter for client registration and password logins.
    rate_limiter: OAuthRateLimiter,
    /// Outstanding authorize-form CSRF tokens, mapped to their expiry instant.
    csrf_tokens: Arc<RwLock<HashMap<String, Instant>>>,
}
80
81impl OAuthState {
82 #[allow(clippy::too_many_arguments)]
83 pub fn new(
84 pool: sqlx::PgPool,
85 auth_middleware: Arc<AuthMiddleware>,
86 token_issuer: Arc<dyn forge_core::TokenIssuer>,
87 access_token_ttl_secs: i64,
88 refresh_token_ttl_days: i64,
89 auth_is_hmac: bool,
90 project_name: String,
91 jwt_secret: String,
92 ) -> Self {
93 Self {
94 pool,
95 auth_middleware,
96 token_issuer,
97 access_token_ttl_secs,
98 refresh_token_ttl_days,
99 auth_is_hmac,
100 project_name,
101 jwt_secret,
102 rate_limiter: OAuthRateLimiter::default(),
103 csrf_tokens: Arc::new(RwLock::new(HashMap::new())),
104 }
105 }
106
107 async fn store_csrf(&self, token: &str) {
108 let mut tokens = self.csrf_tokens.write().await;
109 let now = Instant::now();
110 let expiry = now + Duration::from_secs(600); tokens.insert(token.to_string(), expiry);
112 if tokens.len() > RATE_CLEANUP_THRESHOLD {
114 tokens.retain(|_, exp| *exp > now);
115 }
116 }
117
118 async fn validate_csrf(&self, token: &str) -> bool {
119 let mut tokens = self.csrf_tokens.write().await;
120 if let Some(expiry) = tokens.remove(token) {
121 expiry > Instant::now()
122 } else {
123 false
124 }
125 }
126}
127
/// Authorization server metadata document (field names per RFC 8414).
///
/// Field order is kept as-is because it is the serialized JSON key order.
#[derive(Serialize)]
pub struct AuthorizationServerMetadata {
    /// Issuer identifier: the request-derived base URL.
    issuer: String,
    /// Absolute URL of the authorize endpoint.
    authorization_endpoint: String,
    /// Absolute URL of the token endpoint.
    token_endpoint: String,
    /// Absolute URL of the dynamic client registration endpoint.
    registration_endpoint: String,
    /// Supported `response_type` values.
    response_types_supported: Vec<String>,
    /// Supported `grant_type` values.
    grant_types_supported: Vec<String>,
    /// Supported PKCE methods.
    code_challenge_methods_supported: Vec<String>,
    /// Supported client authentication methods at the token endpoint.
    token_endpoint_auth_methods_supported: Vec<String>,
}
141
142pub async fn well_known_oauth_metadata(
143 headers: HeaderMap,
144 State(_state): State<Arc<OAuthState>>,
145) -> Json<AuthorizationServerMetadata> {
146 let base = base_url_from_headers(&headers);
147 Json(AuthorizationServerMetadata {
148 issuer: base.clone(),
149 authorization_endpoint: format!("{base}/_api/oauth/authorize"),
150 token_endpoint: format!("{base}/_api/oauth/token"),
151 registration_endpoint: format!("{base}/_api/oauth/register"),
152 response_types_supported: vec!["code".into()],
153 grant_types_supported: vec!["authorization_code".into(), "refresh_token".into()],
154 code_challenge_methods_supported: vec![CHALLENGE_METHOD_S256.into()],
155 token_endpoint_auth_methods_supported: vec!["none".into()],
156 })
157}
158
/// Protected resource metadata document (field names per RFC 9728).
#[derive(Serialize)]
pub struct ProtectedResourceMetadata {
    /// Identifier of the protected resource.
    resource: String,
    /// Issuer URLs of the authorization servers that protect the resource.
    authorization_servers: Vec<String>,
}
164
165pub async fn well_known_resource_metadata(
166 headers: HeaderMap,
167 State(_state): State<Arc<OAuthState>>,
168) -> Json<ProtectedResourceMetadata> {
169 let base = base_url_from_headers(&headers);
170 Json(ProtectedResourceMetadata {
171 resource: base.clone(),
172 authorization_servers: vec![base],
173 })
174}
175
/// Body of a dynamic client registration request (RFC 7591 field names).
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Human-readable client name, shown on the authorize page when present.
    pub client_name: Option<String>,
    /// Redirect URIs to register.
    /// The field is mandatory for deserialization, and an empty list is rejected.
    pub redirect_uris: Vec<String>,
    /// Requested grant types.
    /// Not persisted; only echoed back in the registration response.
    #[serde(default)]
    pub grant_types: Vec<String>,
    /// Requested token endpoint authentication method (optional).
    #[serde(default)]
    pub token_endpoint_auth_method: Option<String>,
}
187
/// Successful dynamic client registration response (RFC 7591 field names).
#[derive(Serialize)]
pub struct RegisterResponse {
    /// Server-generated client identifier (a random UUID).
    pub client_id: String,
    /// Client name as submitted.
    pub client_name: Option<String>,
    /// Registered redirect URIs.
    pub redirect_uris: Vec<String>,
    /// Grant types: the requested ones, or `authorization_code` if none were given.
    pub grant_types: Vec<String>,
    /// Token endpoint authentication method recorded for the client.
    pub token_endpoint_auth_method: String,
}
196
197pub async fn oauth_register(
198 headers: HeaderMap,
199 State(state): State<Arc<OAuthState>>,
200 Json(req): Json<RegisterRequest>,
201) -> Response {
202 let ip = client_ip(&headers);
203 let rate_key = format!("oauth_register:{ip}");
204 if !state
205 .rate_limiter
206 .check(&rate_key, REGISTER_RATE_LIMIT)
207 .await
208 {
209 return (
210 StatusCode::TOO_MANY_REQUESTS,
211 Json(serde_json::json!({
212 "error": "too_many_requests",
213 "error_description": "Rate limit exceeded for client registration"
214 })),
215 )
216 .into_response();
217 }
218
219 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM forge_oauth_clients")
221 .fetch_one(&state.pool)
222 .await
223 .unwrap_or(0);
224 if count >= MAX_REGISTERED_CLIENTS {
225 return (
226 StatusCode::BAD_REQUEST,
227 Json(serde_json::json!({
228 "error": "too_many_clients",
229 "error_description": "Maximum number of registered clients reached"
230 })),
231 )
232 .into_response();
233 }
234
235 if req.redirect_uris.is_empty() {
236 return (
237 StatusCode::BAD_REQUEST,
238 Json(serde_json::json!({
239 "error": "invalid_client_metadata",
240 "error_description": "redirect_uris is required"
241 })),
242 )
243 .into_response();
244 }
245
246 let client_id = Uuid::new_v4().to_string();
247 let auth_method = req.token_endpoint_auth_method.as_deref().unwrap_or("none");
248
249 let result = sqlx::query(
250 "INSERT INTO forge_oauth_clients (client_id, client_name, redirect_uris, token_endpoint_auth_method) \
251 VALUES ($1, $2, $3, $4)"
252 )
253 .bind(&client_id)
254 .bind(&req.client_name)
255 .bind(&req.redirect_uris)
256 .bind(auth_method)
257 .execute(&state.pool)
258 .await;
259
260 if let Err(e) = result {
261 tracing::error!("Failed to register OAuth client: {e}");
262 return (
263 StatusCode::INTERNAL_SERVER_ERROR,
264 Json(serde_json::json!({
265 "error": "server_error",
266 "error_description": "Failed to register client"
267 })),
268 )
269 .into_response();
270 }
271
272 let grant_types = if req.grant_types.is_empty() {
273 vec!["authorization_code".into()]
274 } else {
275 req.grant_types
276 };
277
278 (
279 StatusCode::CREATED,
280 Json(RegisterResponse {
281 client_id,
282 client_name: req.client_name,
283 redirect_uris: req.redirect_uris,
284 grant_types,
285 token_endpoint_auth_method: auth_method.to_string(),
286 }),
287 )
288 .into_response()
289}
290
/// Query parameters of an authorization request (GET on the authorize endpoint).
#[derive(Deserialize)]
pub struct AuthorizeQuery {
    /// Registered client identifier.
    pub client_id: String,
    /// Where to send the user afterwards; must match one of the client's
    /// registered redirect URIs.
    pub redirect_uri: String,
    /// PKCE code challenge, stored with the issued authorization code.
    pub code_challenge: String,
    /// PKCE method; defaults to `S256`, the only value accepted.
    #[serde(default = "default_s256")]
    pub code_challenge_method: String,
    /// Opaque client state, echoed back on the redirect.
    pub state: Option<String>,
    /// Space-separated requested scopes.
    pub scope: Option<String>,
    /// Requested response type (optional).
    pub response_type: Option<String>,
}
304
305fn default_s256() -> String {
306 CHALLENGE_METHOD_S256.into()
307}
308
309pub async fn oauth_authorize_get(
310 headers: HeaderMap,
311 Query(params): Query<AuthorizeQuery>,
312 State(state): State<Arc<OAuthState>>,
313) -> Response {
314 let client = sqlx::query_as::<_, (String, Option<String>, Vec<String>)>(
316 "SELECT client_id, client_name, redirect_uris FROM forge_oauth_clients WHERE client_id = $1"
317 )
318 .bind(¶ms.client_id)
319 .fetch_optional(&state.pool)
320 .await;
321
322 let (_, client_name, redirect_uris) = match client {
323 Ok(Some(c)) => c,
324 Ok(None) => {
325 return (
326 StatusCode::BAD_REQUEST,
327 Json(serde_json::json!({
328 "error": "invalid_client",
329 "error_description": "Unknown client_id"
330 })),
331 )
332 .into_response();
333 }
334 Err(e) => {
335 tracing::error!("OAuth client lookup failed: {e}");
336 return (
337 StatusCode::INTERNAL_SERVER_ERROR,
338 Json(serde_json::json!({
339 "error": "server_error"
340 })),
341 )
342 .into_response();
343 }
344 };
345
346 if !validate_redirect_uri(¶ms.redirect_uri, &redirect_uris) {
348 return (
349 StatusCode::BAD_REQUEST,
350 Json(serde_json::json!({
351 "error": "invalid_redirect_uri",
352 "error_description": "redirect_uri does not match any registered URI"
353 })),
354 )
355 .into_response();
356 }
357
358 if params.code_challenge_method != CHALLENGE_METHOD_S256 {
359 return (
360 StatusCode::BAD_REQUEST,
361 Json(serde_json::json!({
362 "error": "invalid_request",
363 "error_description": "Only S256 code_challenge_method is supported"
364 })),
365 )
366 .into_response();
367 }
368
369 let session_subject = extract_cookie(&headers, "forge_session")
371 .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
372 let has_session = session_subject.is_some();
373
374 let csrf_token = oauth::generate_random_token();
376 state.store_csrf(&csrf_token).await;
377
378 let auth_mode = if has_session {
379 "session" } else if state.auth_is_hmac {
381 "hmac" } else {
383 "external" };
385 let display_name = client_name.as_deref().unwrap_or(¶ms.client_id);
386
387 let html = AUTHORIZE_PAGE
388 .replace("{{app_name}}", &html_escape(&state.project_name))
389 .replace("{{client_name}}", &html_escape(display_name))
390 .replace("{{csrf_token}}", &csrf_token)
391 .replace("{{client_id}}", &html_escape(¶ms.client_id))
392 .replace("{{redirect_uri}}", &html_escape(¶ms.redirect_uri))
393 .replace("{{code_challenge}}", &html_escape(¶ms.code_challenge))
394 .replace(
395 "{{code_challenge_method}}",
396 &html_escape(¶ms.code_challenge_method),
397 )
398 .replace(
399 "{{state}}",
400 &html_escape(params.state.as_deref().unwrap_or("")),
401 )
402 .replace(
403 "{{scope}}",
404 &html_escape(params.scope.as_deref().unwrap_or("")),
405 )
406 .replace("{{auth_mode}}", auth_mode)
407 .replace("{{authorize_url}}", "/_api/oauth/authorize")
408 .replace("{{error_message}}", "");
409
410 let mut response = (StatusCode::OK, Html(html)).into_response();
411 response
413 .headers_mut()
414 .insert("X-Frame-Options", HeaderValue::from_static("DENY"));
415 response.headers_mut().insert(
416 "Content-Security-Policy",
417 HeaderValue::from_static("frame-ancestors 'none'"),
418 );
419 let cookie = format!(
421 "forge_oauth_csrf={csrf_token}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=600"
422 );
423 if let Ok(cookie_val) = HeaderValue::from_str(&cookie) {
424 response
425 .headers_mut()
426 .insert(header::SET_COOKIE, cookie_val);
427 }
428 response
429}
430
/// Form fields posted by the authorize page.
///
/// Most fields echo the original `AuthorizeQuery` values back through hidden
/// inputs, so they are client-controlled and must be re-validated.
#[derive(Deserialize)]
pub struct AuthorizeForm {
    /// Must equal the `forge_oauth_csrf` cookie and be a live server-side token.
    pub csrf_token: String,
    /// Client identifier being authorized.
    pub client_id: String,
    /// Redirect target; re-checked against the client's registered URIs.
    pub redirect_uri: String,
    /// PKCE challenge stored with the issued code.
    pub code_challenge: String,
    /// PKCE method stored with the issued code.
    pub code_challenge_method: String,
    /// Opaque client state, echoed back on the redirect.
    pub state: Option<String>,
    /// Space-separated scopes, stored with the issued code.
    pub scope: Option<String>,
    /// Echoed from the query; not inspected by the POST handler.
    pub response_type: Option<String>,
    /// Bearer token identifying the user.
    /// Consulted when there is no valid session cookie.
    pub token: Option<String>,
    /// Login email; only honoured in HMAC auth mode.
    pub email: Option<String>,
    /// Login password; only honoured in HMAC auth mode (checked with bcrypt).
    pub password: Option<String>,
}
447
448pub async fn oauth_authorize_post(
449 headers: HeaderMap,
450 State(state): State<Arc<OAuthState>>,
451 axum::Form(form): axum::Form<AuthorizeForm>,
452) -> Response {
453 let csrf_from_cookie = extract_cookie(&headers, "forge_oauth_csrf");
455 let csrf_valid = if let Some(cookie_csrf) = csrf_from_cookie {
456 cookie_csrf == form.csrf_token && state.validate_csrf(&form.csrf_token).await
457 } else {
458 false
459 };
460 if !csrf_valid {
461 return (
462 StatusCode::FORBIDDEN,
463 Json(serde_json::json!({
464 "error": "csrf_validation_failed",
465 "error_description": "Invalid or expired CSRF token. Please try again."
466 })),
467 )
468 .into_response();
469 }
470
471 let ip = client_ip(&headers);
473 let rate_key = format!("oauth_login:{ip}");
474
475 let client = sqlx::query_as::<_, (Vec<String>,)>(
477 "SELECT redirect_uris FROM forge_oauth_clients WHERE client_id = $1",
478 )
479 .bind(&form.client_id)
480 .fetch_optional(&state.pool)
481 .await;
482
483 let redirect_uris = match client {
484 Ok(Some((uris,))) => uris,
485 _ => {
486 return (
487 StatusCode::BAD_REQUEST,
488 Json(serde_json::json!({
489 "error": "invalid_client"
490 })),
491 )
492 .into_response();
493 }
494 };
495
496 if !validate_redirect_uri(&form.redirect_uri, &redirect_uris) {
497 return (
498 StatusCode::BAD_REQUEST,
499 Json(serde_json::json!({
500 "error": "invalid_redirect_uri"
501 })),
502 )
503 .into_response();
504 }
505
506 let user_id: Uuid;
508
509 let session_subject = extract_cookie(&headers, "forge_session")
510 .and_then(|v| super::auth::verify_session_cookie(&v, &state.jwt_secret));
511
512 if let Some(subject) = session_subject {
513 user_id = subject.parse::<Uuid>().unwrap_or_else(|_| {
516 use sha2::Digest;
518 let hash: [u8; 32] = sha2::Sha256::digest(subject.as_bytes()).into();
519 let mut bytes = [0u8; 16];
520 bytes.copy_from_slice(&hash[..16]);
521 Uuid::from_bytes(bytes)
522 });
523 } else if let Some(token) = &form.token {
524 match state.auth_middleware.validate_token_async(token).await {
526 Ok(claims) => {
527 user_id = claims
528 .user_id()
529 .ok_or(())
530 .map_err(|_| ())
531 .unwrap_or_default();
532 if user_id.is_nil() {
533 return authorize_error_redirect(
534 &form.redirect_uri,
535 form.state.as_deref(),
536 "access_denied",
537 "Invalid user identity in token",
538 );
539 }
540 }
541 Err(_) => {
542 return authorize_error_redirect(
543 &form.redirect_uri,
544 form.state.as_deref(),
545 "access_denied",
546 "Invalid or expired token. Please log in again.",
547 );
548 }
549 }
550 } else if let (Some(email), Some(password)) = (&form.email, &form.password) {
551 if !state.auth_is_hmac {
553 return authorize_error_redirect(
554 &form.redirect_uri,
555 form.state.as_deref(),
556 "access_denied",
557 "Direct login not supported with external auth provider",
558 );
559 }
560
561 if !state
562 .rate_limiter
563 .check(&rate_key, LOGIN_FAIL_RATE_LIMIT)
564 .await
565 {
566 return authorize_error_redirect(
567 &form.redirect_uri,
568 form.state.as_deref(),
569 "access_denied",
570 "Too many login attempts. Please try again later.",
571 );
572 }
573
574 let row = sqlx::query_as::<_, (Uuid, Option<String>, Option<String>)>(
576 "SELECT id, password_hash, role::TEXT FROM users WHERE email = $1",
577 )
578 .bind(email)
579 .fetch_optional(&state.pool)
580 .await;
581
582 match row {
583 Ok(Some((uid, Some(hash), _role))) => match bcrypt::verify(password, &hash) {
584 Ok(true) => {
585 user_id = uid;
586 }
587 _ => {
588 return authorize_error_redirect(
589 &form.redirect_uri,
590 form.state.as_deref(),
591 "access_denied",
592 "Invalid email or password",
593 );
594 }
595 },
596 _ => {
597 return authorize_error_redirect(
598 &form.redirect_uri,
599 form.state.as_deref(),
600 "access_denied",
601 "Invalid email or password",
602 );
603 }
604 }
605 } else {
606 return (
607 StatusCode::BAD_REQUEST,
608 Json(serde_json::json!({
609 "error": "invalid_request",
610 "error_description": "Must provide either a token or email/password"
611 })),
612 )
613 .into_response();
614 }
615
616 let code = oauth::generate_random_token();
618 let expires_at = Utc::now() + chrono::Duration::seconds(AUTH_CODE_TTL_SECS);
619 let scopes: Vec<String> = form
620 .scope
621 .as_deref()
622 .map(|s| s.split_whitespace().map(String::from).collect())
623 .unwrap_or_default();
624
625 let result = sqlx::query(
626 "INSERT INTO forge_oauth_codes \
627 (code, client_id, user_id, redirect_uri, code_challenge, code_challenge_method, scopes, expires_at) \
628 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
629 )
630 .bind(&code)
631 .bind(&form.client_id)
632 .bind(user_id)
633 .bind(&form.redirect_uri)
634 .bind(&form.code_challenge)
635 .bind(&form.code_challenge_method)
636 .bind(&scopes)
637 .bind(expires_at)
638 .execute(&state.pool)
639 .await;
640
641 if let Err(e) = result {
642 tracing::error!("Failed to store authorization code: {e}");
643 return authorize_error_redirect(
644 &form.redirect_uri,
645 form.state.as_deref(),
646 "server_error",
647 "Failed to generate authorization code",
648 );
649 }
650
651 let mut redirect_url = format!("{}?code={}", form.redirect_uri, urlencoding(&code));
653 if let Some(st) = &form.state {
654 redirect_url.push_str(&format!("&state={}", urlencoding(st)));
655 }
656
657 let mut response = Redirect::to(&redirect_url).into_response();
658 response
659 .headers_mut()
660 .insert("Referrer-Policy", HeaderValue::from_static("no-referrer"));
661
662 let cookie_value = super::auth::sign_session_cookie(&user_id.to_string(), &state.jwt_secret);
666 let secure_flag = if is_https(&headers) { "; Secure" } else { "" };
667 let session_cookie = format!(
668 "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=86400{secure_flag}"
669 );
670 if let Ok(val) = HeaderValue::from_str(&session_cookie) {
671 response.headers_mut().append(header::SET_COOKIE, val);
672 }
673
674 response
675}
676
/// Token endpoint request, accepted as JSON or as a URL-encoded form.
#[derive(Deserialize)]
pub struct TokenRequest {
    /// `authorization_code` or `refresh_token`.
    pub grant_type: String,
    /// Authorization code (required for `authorization_code`).
    pub code: Option<String>,
    /// Must equal the redirect URI the code was issued for (`authorization_code`).
    pub redirect_uri: Option<String>,
    /// PKCE verifier checked against the stored S256 challenge (`authorization_code`).
    pub code_verifier: Option<String>,
    /// Client identifier.
    /// Required for `authorization_code`; optional for `refresh_token`.
    pub client_id: Option<String>,
    /// Refresh token to rotate (required for `refresh_token`).
    pub refresh_token: Option<String>,
}
688
/// Successful token endpoint response.
#[derive(Serialize)]
pub struct TokenResponse {
    /// Signed access token.
    pub access_token: String,
    /// Always `Bearer`.
    pub token_type: String,
    /// Access token lifetime in seconds.
    pub expires_in: i64,
    /// New refresh token.
    pub refresh_token: String,
}
696
697pub async fn oauth_token(
700 State(state): State<Arc<OAuthState>>,
701 headers: HeaderMap,
702 body: axum::body::Bytes,
703) -> Response {
704 let content_type = headers
705 .get(header::CONTENT_TYPE)
706 .and_then(|v| v.to_str().ok())
707 .unwrap_or("");
708
709 let req: TokenRequest = if content_type.starts_with("application/json") {
710 match serde_json::from_slice(&body) {
711 Ok(r) => r,
712 Err(e) => return token_error("invalid_request", &format!("Invalid JSON: {e}")),
713 }
714 } else {
715 match serde_urlencoded::from_bytes(&body) {
717 Ok(r) => r,
718 Err(e) => return token_error("invalid_request", &format!("Invalid form data: {e}")),
719 }
720 };
721
722 match req.grant_type.as_str() {
723 "authorization_code" => handle_code_exchange(&state, &req).await,
724 "refresh_token" => handle_refresh(&state, &req).await,
725 _ => (
726 StatusCode::BAD_REQUEST,
727 Json(serde_json::json!({
728 "error": "unsupported_grant_type"
729 })),
730 )
731 .into_response(),
732 }
733}
734
735async fn handle_code_exchange(state: &OAuthState, req: &TokenRequest) -> Response {
736 let code = match &req.code {
737 Some(c) => c,
738 None => return token_error("invalid_request", "code is required"),
739 };
740 let code_verifier = match &req.code_verifier {
741 Some(v) => v,
742 None => return token_error("invalid_request", "code_verifier is required"),
743 };
744 let redirect_uri = match &req.redirect_uri {
745 Some(r) => r,
746 None => return token_error("invalid_request", "redirect_uri is required"),
747 };
748 let client_id = match &req.client_id {
749 Some(c) => c,
750 None => return token_error("invalid_request", "client_id is required"),
751 };
752
753 let row = sqlx::query_as::<_, (String, Uuid, String, String, String, chrono::DateTime<Utc>)>(
755 "UPDATE forge_oauth_codes SET used_at = now() \
756 WHERE code = $1 AND used_at IS NULL \
757 RETURNING client_id, user_id, redirect_uri, code_challenge, code_challenge_method, expires_at"
758 )
759 .bind(code)
760 .fetch_optional(&state.pool)
761 .await;
762
763 let (
764 stored_client_id,
765 user_id,
766 stored_redirect,
767 stored_challenge,
768 challenge_method,
769 expires_at,
770 ) = match row {
771 Ok(Some(r)) => r,
772 Ok(None) => {
773 return token_error(
774 "invalid_grant",
775 "Invalid or already used authorization code",
776 );
777 }
778 Err(e) => {
779 tracing::error!("Failed to exchange authorization code: {e}");
780 return token_error("server_error", "Failed to exchange code");
781 }
782 };
783
784 if Utc::now() > expires_at {
786 return token_error("invalid_grant", "Authorization code has expired");
787 }
788
789 if *client_id != stored_client_id {
791 return token_error("invalid_grant", "client_id does not match");
792 }
793
794 if *redirect_uri != stored_redirect {
796 return token_error("invalid_grant", "redirect_uri does not match");
797 }
798
799 if challenge_method != CHALLENGE_METHOD_S256 {
800 return token_error("invalid_request", "Unsupported code_challenge_method");
801 }
802 if !forge_core::oauth::pkce::verify_s256(code_verifier, &stored_challenge) {
803 return token_error("invalid_grant", "PKCE verification failed");
804 }
805
806 let access_ttl = state.access_token_ttl_secs;
807 let refresh_ttl = state.refresh_token_ttl_days;
808
809 let pair = forge_core::auth::tokens::issue_token_pair_with_client(
810 &state.pool,
811 user_id,
812 &["user"],
813 access_ttl,
814 refresh_ttl,
815 Some(client_id),
816 mcp_token_issuer(state.token_issuer.clone()),
817 )
818 .await;
819
820 match pair {
821 Ok(pair) => (
822 StatusCode::OK,
823 Json(TokenResponse {
824 access_token: pair.access_token,
825 token_type: "Bearer".into(),
826 expires_in: access_ttl,
827 refresh_token: pair.refresh_token,
828 }),
829 )
830 .into_response(),
831 Err(e) => {
832 tracing::error!("Failed to issue token pair: {e}");
833 token_error("server_error", "Failed to issue tokens")
834 }
835 }
836}
837
838async fn handle_refresh(state: &OAuthState, req: &TokenRequest) -> Response {
839 let refresh_token = match &req.refresh_token {
840 Some(t) => t,
841 None => return token_error("invalid_request", "refresh_token is required"),
842 };
843 let client_id = req.client_id.as_deref();
844
845 let access_ttl = state.access_token_ttl_secs;
846 let refresh_ttl = state.refresh_token_ttl_days;
847
848 let pair = forge_core::auth::tokens::rotate_refresh_token_with_client(
849 &state.pool,
850 refresh_token,
851 &["user"],
852 access_ttl,
853 refresh_ttl,
854 client_id,
855 mcp_token_issuer(state.token_issuer.clone()),
856 )
857 .await;
858
859 match pair {
860 Ok(pair) => (
861 StatusCode::OK,
862 Json(TokenResponse {
863 access_token: pair.access_token,
864 token_type: "Bearer".into(),
865 expires_in: access_ttl,
866 refresh_token: pair.refresh_token,
867 }),
868 )
869 .into_response(),
870 Err(_) => token_error("invalid_grant", "Invalid or expired refresh token"),
871 }
872}
873
874fn mcp_token_issuer(
878 issuer: Arc<dyn forge_core::TokenIssuer>,
879) -> impl FnOnce(Uuid, &[&str], i64) -> forge_core::Result<String> {
880 move |uid, roles, ttl| {
881 let claims = Claims::builder()
882 .subject(uid)
883 .roles(roles.iter().map(|s| s.to_string()).collect())
884 .claim("aud".to_string(), serde_json::json!(MCP_AUDIENCE))
885 .duration_secs(ttl)
886 .build()
887 .map_err(forge_core::ForgeError::Internal)?;
888 issuer.sign(&claims)
889 }
890}
891
892fn is_https(headers: &HeaderMap) -> bool {
893 headers
894 .get("x-forwarded-proto")
895 .and_then(|v| v.to_str().ok())
896 .map(|s| s == "https")
897 .unwrap_or(false)
898}
899
900fn token_error(error: &str, description: &str) -> Response {
901 (
902 StatusCode::BAD_REQUEST,
903 Json(serde_json::json!({
904 "error": error,
905 "error_description": description
906 })),
907 )
908 .into_response()
909}
910
911fn authorize_error_redirect(
912 redirect_uri: &str,
913 state: Option<&str>,
914 error: &str,
915 description: &str,
916) -> Response {
917 let mut url = format!(
918 "{}?error={}&error_description={}",
919 redirect_uri,
920 urlencoding(error),
921 urlencoding(description),
922 );
923 if let Some(st) = state {
924 url.push_str(&format!("&state={}", urlencoding(st)));
925 }
926 Redirect::to(&url).into_response()
927}
928
929fn base_url_from_headers(headers: &HeaderMap) -> String {
930 let host = headers
931 .get("host")
932 .and_then(|v| v.to_str().ok())
933 .unwrap_or("localhost:9081");
934
935 let scheme = headers
936 .get("x-forwarded-proto")
937 .and_then(|v| v.to_str().ok())
938 .unwrap_or("http");
939
940 format!("{scheme}://{host}")
941}
942
943fn client_ip(headers: &HeaderMap) -> String {
944 headers
945 .get("x-forwarded-for")
946 .and_then(|v| v.to_str().ok())
947 .and_then(|s| s.split(',').next())
948 .map(|s| s.trim().to_string())
949 .or_else(|| {
950 headers
951 .get("x-real-ip")
952 .and_then(|v| v.to_str().ok())
953 .map(String::from)
954 })
955 .unwrap_or_else(|| "unknown".to_string())
956}
957
958fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
959 headers
960 .get(header::COOKIE)
961 .and_then(|v| v.to_str().ok())
962 .and_then(|cookies| {
963 cookies.split(';').map(|c| c.trim()).find_map(|c| {
964 let (k, v) = c.split_once('=')?;
965 if k == name { Some(v.to_string()) } else { None }
966 })
967 })
968}
969
/// Escapes `s` for safe interpolation into HTML text and quoted attribute
/// values (`&`, `<`, `>`, `"` and `'`).
///
/// The input is processed in a single pass, so entities produced for one
/// character are never re-escaped. Every other character is copied unchanged.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#x27;"),
            other => escaped.push(other),
        }
    }
    escaped
}
977
/// Percent-encodes `s` for use in a URL query component.
///
/// Every byte outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) is
/// written as `%XX` with uppercase hex digits. Multi-byte UTF-8 characters are
/// therefore encoded byte by byte.
fn urlencoding(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved = byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}