1use axum::{
2 extract::{Form, Json, Path, Query, State},
3 http::{header, HeaderMap, StatusCode},
4 response::{Html, IntoResponse, Redirect},
5 routing::{get, post},
6 Router,
7};
8use base64::{engine::general_purpose, Engine};
9use chrono::{Duration, Utc};
10use http::HeaderValue;
11use jsonwebtoken::jwk::{CommonParameters, Jwk};
12use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header};
13use once_cell::sync::Lazy;
14use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
15use rsa::traits::PublicKeyParts;
16use rsa::RsaPrivateKey;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use sha2::Digest;
20use std::{
21 collections::{HashMap, HashSet},
22 net::SocketAddr,
23 sync::{Arc, RwLock},
24};
25use tokio::{net::TcpListener, task::JoinHandle};
26use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
27use uuid::Uuid;
28
/// Key ID (`kid`) shared by the published JWKS entry and the header of every
/// JWT this issuer signs, so verifiers can pick the matching key.
pub const KID: &str = "oauth-by-rustmcp";
30
31use serde_json::Value;
33
/// Static configuration of the mock OAuth 2.0 / OpenID Connect issuer.
///
/// The capability sets below are published verbatim in the discovery document;
/// `scopes_supported` and `grant_types_supported` are also used for validation.
#[derive(Debug, Clone)]
pub struct IssuerConfig {
    /// URL scheme of the issuer URL (e.g. `"http"`).
    pub scheme: String,
    /// Host name used in the issuer URL (the server itself binds to 127.0.0.1).
    pub host: String,
    /// Port to bind; `0` lets the OS choose, and the chosen port is then used in the issuer URL.
    pub port: u16,

    /// Scopes the issuer recognises; any other requested scope is rejected.
    pub scopes_supported: HashSet<String>,
    /// Claims advertised in discovery.
    pub claims_supported: Vec<String>,
    /// Grant types advertised in discovery.
    pub grant_types_supported: HashSet<String>,
    /// Response types advertised in discovery.
    /// NOTE(review): the default lists `token`/`id_token`, but `/authorize` only accepts `code`.
    pub response_types_supported: HashSet<String>,
    /// Token-endpoint client authentication methods advertised in discovery.
    pub token_endpoint_auth_methods_supported: HashSet<String>,
    /// PKCE challenge methods advertised in discovery.
    pub code_challenge_methods_supported: HashSet<String>,
    /// Subject identifier types advertised in discovery.
    pub subject_types_supported: Vec<String>,
    /// ID-token signing algorithms advertised in discovery.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// When `true`, dynamic registration always issues a client secret, even for
    /// clients that register with `token_endpoint_auth_method = "none"`.
    pub generate_client_secret_for_dcr: bool,
    /// Origins permitted by the CORS layer.
    pub allowed_origins: Vec<String>,
}

impl Default for IssuerConfig {
    /// A localhost issuer on an OS-assigned port with a broad OIDC capability set.
    fn default() -> Self {
        /// Owned `Vec<String>` from string literals.
        fn list(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        }
        /// Owned `HashSet<String>` from string literals.
        fn set(items: &[&str]) -> HashSet<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        }

        Self {
            scheme: String::from("http"),
            host: String::from("localhost"),
            port: 0,
            scopes_supported: set(&[
                "openid",
                "profile",
                "email",
                "offline_access",
                "address",
                "phone",
            ]),
            claims_supported: list(&[
                "sub",
                "name",
                "given_name",
                "family_name",
                "email",
                "email_verified",
                "picture",
                "locale",
            ]),
            grant_types_supported: set(&[
                "authorization_code",
                "refresh_token",
                "client_credentials",
            ]),
            response_types_supported: set(&["code", "token", "id_token"]),
            token_endpoint_auth_methods_supported: set(&[
                "client_secret_basic",
                "client_secret_post",
                "none",
                "private_key_jwt",
            ]),
            code_challenge_methods_supported: set(&["plain", "S256"]),
            subject_types_supported: list(&["public"]),
            id_token_signing_alg_values_supported: list(&["RS256"]),
            generate_client_secret_for_dcr: true,
            allowed_origins: list(&[
                "http://localhost:3001",
                "http://localhost:8080",
                "http://localhost:6274",
            ]),
        }
    }
}
110
111impl IssuerConfig {
112 pub fn to_discovery_document(&self, issuer: String) -> Value {
113 let iss = issuer;
114 json!({
115 "issuer": iss,
116 "authorization_endpoint": format!("{}/authorize", iss),
117 "token_endpoint": format!("{}/token", iss),
118 "userinfo_endpoint": format!("{}/userinfo", iss),
119 "jwks_uri": format!("{}/.well-known/jwks.json", iss),
120 "registration_endpoint": format!("{}/register", iss),
121 "revocation_endpoint": format!("{}/revoke", iss),
122 "introspection_endpoint": format!("{}/introspect", iss),
123 "scopes_supported": self.scopes_supported.iter().collect::<Vec<_>>(),
124 "claims_supported": &self.claims_supported,
125 "grant_types_supported": self.grant_types_supported.iter().collect::<Vec<_>>(),
126 "response_types_supported": self.response_types_supported.iter().collect::<Vec<_>>(),
127 "token_endpoint_auth_methods_supported": self.token_endpoint_auth_methods_supported.iter().collect::<Vec<_>>(),
128 "code_challenge_methods_supported": self.code_challenge_methods_supported.iter().collect::<Vec<_>>(),
129 "subject_types_supported": &self.subject_types_supported,
130 "id_token_signing_alg_values_supported": &self.id_token_signing_alg_values_supported,
131 })
132 }
133
134 pub fn validate_scope(&self, scope: &str) -> Result<String, String> {
135 let requested: HashSet<_> = scope.split_whitespace().map(|s| s.to_string()).collect();
136 let unknown: Vec<_> = requested
137 .difference(&self.scopes_supported)
138 .cloned()
139 .collect();
140 if unknown.is_empty() {
141 Ok(scope.to_string())
142 } else {
143 Err(format!("invalid_scope: {}", unknown.join(" ")))
144 }
145 }
146
147 pub fn validate_grant_type(&self, grant: &str) -> bool {
148 self.grant_types_supported.contains(grant)
149 }
150}
151
/// Shared state of the mock issuer's HTTP handlers.
///
/// Every store is an in-memory map behind `Arc<RwLock<..>>`. Cloning `AppState`
/// (handlers receive clones through axum's `State` extractor) therefore shares
/// the same underlying data.
#[derive(Clone)]
pub struct AppState {
    /// Issuer configuration (scopes, grant types, CORS origins, ...).
    pub config: Arc<IssuerConfig>,
    /// Issuer URL `scheme://host:port`; `start` rewrites it once the real port is bound.
    pub base_url: String,
    /// Registered clients, keyed by `client_id`.
    pub clients: Arc<RwLock<HashMap<String, Client>>>,
    /// Outstanding authorization codes, keyed by code; removed when redeemed.
    pub codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
    /// Issued access tokens, keyed by the access-token string (the JWT).
    pub tokens: Arc<RwLock<HashMap<String, Token>>>,
    /// Issued refresh tokens, keyed by the refresh-token string.
    pub refresh_tokens: Arc<RwLock<HashMap<String, Token>>>,
}
186
/// A client created through dynamic client registration.
///
/// Field names mirror the RFC 7591 metadata keys they are read from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Client {
    /// Server-generated identifier (a random UUID).
    pub client_id: String,
    /// Server-generated secret. `None` only for `"none"`-auth clients when the
    /// issuer's `generate_client_secret_for_dcr` is off.
    pub client_secret: Option<String>,
    /// Registered redirect URIs; `/authorize` only redirects to one of these.
    pub redirect_uris: Vec<String>,
    /// Requested grant types (default `["authorization_code"]`).
    pub grant_types: Vec<String>,
    /// Requested response types (default `["code"]`).
    pub response_types: Vec<String>,
    /// Space-separated scopes the client may be granted (empty when none were registered).
    pub scope: String,
    /// Declared token-endpoint auth method (default `"client_secret_basic"`).
    pub token_endpoint_auth_method: String,
    /// Human-readable client name.
    pub client_name: Option<String>,
    /// Client home page URL.
    pub client_uri: Option<String>,
    /// Client logo URL.
    pub logo_uri: Option<String>,
    /// Contact addresses.
    pub contacts: Vec<String>,
    /// Privacy-policy URL.
    pub policy_uri: Option<String>,
    /// Terms-of-service URL.
    pub tos_uri: Option<String>,
    /// Raw `jwks` metadata value, stored as given.
    pub jwks: Option<serde_json::Value>,
    /// URL of the client's JWK set.
    pub jwks_uri: Option<String>,
    /// Software identifier supplied at registration.
    pub software_id: Option<String>,
    /// Software version supplied at registration.
    pub software_version: Option<String>,
    /// Registration access token; `AppState::register_client` leaves it `None`.
    pub registration_access_token: Option<String>,
    /// `{issuer}/register/{client_id}`, where the registration can be read back.
    pub registration_client_uri: Option<String>,
}
209
/// A one-time authorization code issued by `/authorize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationCode {
    /// The code value handed to the client.
    pub code: String,
    /// Client the code was issued to.
    pub client_id: String,
    /// Redirect URI the code was delivered to.
    pub redirect_uri: String,
    /// Space-separated granted scopes.
    pub scope: String,
    /// Expiry instant (`/authorize` sets issue time + 10 minutes).
    pub expires_at: chrono::DateTime<Utc>,
    /// PKCE challenge, if the client sent one.
    pub code_challenge: Option<String>,
    /// PKCE method (`plain`/`S256`); `None` is treated as `plain` when redeemed.
    pub code_challenge_method: Option<String>,
    /// Subject the code authorises (`/authorize` always uses `"test-user-123"`).
    pub user_id: String,
}
221
/// Record of an issued access token and its optional refresh token.
///
/// The token endpoint stores separate copies under the access token (in
/// `AppState::tokens`) and under the refresh token (in
/// `AppState::refresh_tokens`). The `revoked` flag is therefore tracked
/// independently per map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Token {
    /// The signed access token (JWT).
    access_token: String,
    /// Refresh token issued together with the access token, if any.
    refresh_token: Option<String>,
    /// Client the token was issued to.
    client_id: String,
    /// Space-separated granted scopes.
    scope: String,
    /// Access-token expiry (issue time + 1 hour); refresh-token copies carry the same value.
    expires_at: chrono::DateTime<Utc>,
    /// Subject of the token.
    user_id: String,
    /// Set by `/revoke` or by refresh-token rotation.
    revoked: bool,
}
232
233impl AppState {
234 pub fn new(config: IssuerConfig) -> Self {
235 let base_url = format!("{}://{}:{}", config.scheme, config.host, config.port);
236 Self {
237 config: Arc::new(config),
238 clients: Arc::new(RwLock::new(HashMap::new())),
239 codes: Arc::new(RwLock::new(HashMap::new())),
240 tokens: Arc::new(RwLock::new(HashMap::new())),
241 refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
242 base_url,
243 }
244 }
245
246 pub fn issuer(&self) -> &str {
247 self.base_url.as_str()
248 }
249
250 pub fn register_client(
251 &self,
252 metadata: serde_json::Value,
253 ) -> Result<Client, (StatusCode, Json<serde_json::Value>)> {
254 let requested_scope = metadata
255 .get("scope")
256 .and_then(|v| v.as_str())
257 .unwrap_or("openid");
258
259 self.config
260 .validate_scope(requested_scope)
261 .map_err(|err| (StatusCode::BAD_REQUEST, Json(json!({"error": err}))))?;
262
263 let client_id = Uuid::new_v4().to_string();
264
265 let client_secret = if self.config.generate_client_secret_for_dcr
266 || metadata
267 .get("token_endpoint_auth_method")
268 .and_then(|v| v.as_str())
269 != Some("none")
270 {
271 Some(generate_token())
272 } else {
273 None
274 };
275
276 let redirect_uris = metadata
277 .get("redirect_uris")
278 .and_then(|v| v.as_array())
279 .map(|arr| {
280 arr.iter()
281 .filter_map(|u| u.as_str().map(|s| s.to_string()))
282 .collect::<Vec<String>>()
283 })
284 .unwrap_or_default();
285
286 if redirect_uris.is_empty()
287 && metadata.get("grant_types").map(|v| {
288 v.as_array()
289 .map(|a| a.contains(&json!("client_credentials")))
290 }) != Some(Some(true))
291 {
292 return Err((
293 StatusCode::BAD_REQUEST,
294 Json(json!({"error": "redirect_uris required"})),
295 ));
296 }
297
298 let client = Client {
299 client_id: client_id.clone(),
300 client_secret: client_secret.clone(),
301 redirect_uris,
302 grant_types: metadata
303 .get("grant_types")
304 .and_then(|v| v.as_array())
305 .map(|arr| {
306 arr.iter()
307 .filter_map(|v| v.as_str().map(|s| s.to_string()))
308 .collect()
309 })
310 .unwrap_or_else(|| vec!["authorization_code".to_string()]),
311 response_types: metadata
312 .get("response_types")
313 .and_then(|v| v.as_array())
314 .map(|arr| {
315 arr.iter()
316 .filter_map(|v| v.as_str().map(|s| s.to_string()))
317 .collect()
318 })
319 .unwrap_or_else(|| vec!["code".to_string()]),
320 scope: metadata
321 .get("scope")
322 .and_then(|v| v.as_str())
323 .unwrap_or("")
324 .to_string(),
325 token_endpoint_auth_method: metadata
326 .get("token_endpoint_auth_method")
327 .and_then(|v| v.as_str())
328 .unwrap_or("client_secret_basic")
329 .to_string(),
330 client_name: metadata
331 .get("client_name")
332 .and_then(|v| v.as_str())
333 .map(|s| s.to_string()),
334 client_uri: metadata
335 .get("client_uri")
336 .and_then(|v| v.as_str())
337 .map(|s| s.to_string()),
338 logo_uri: metadata
339 .get("logo_uri")
340 .and_then(|v| v.as_str())
341 .map(|s| s.to_string()),
342 contacts: metadata
343 .get("contacts")
344 .and_then(|v| v.as_array())
345 .map(|arr| {
346 arr.iter()
347 .filter_map(|v| v.as_str().map(|s| s.to_string()))
348 .collect()
349 })
350 .unwrap_or_default(),
351 policy_uri: metadata
352 .get("policy_uri")
353 .and_then(|v| v.as_str())
354 .map(|s| s.to_string()),
355 tos_uri: metadata
356 .get("tos_uri")
357 .and_then(|v| v.as_str())
358 .map(|s| s.to_string()),
359 jwks: metadata.get("jwks").cloned(),
360 jwks_uri: metadata
361 .get("jwks_uri")
362 .and_then(|v| v.as_str())
363 .map(|s| s.to_string()),
364 software_id: metadata
365 .get("software_id")
366 .and_then(|v| v.as_str())
367 .map(|s| s.to_string()),
368 software_version: metadata
369 .get("software_version")
370 .and_then(|v| v.as_str())
371 .map(|s| s.to_string()),
372 registration_access_token: None,
373 registration_client_uri: Some(format!("{}/register/{}", self.issuer(), client_id)),
374 };
375
376 self.clients
377 .write()
378 .unwrap()
379 .insert(client_id.clone(), client.clone());
380
381 Ok(client)
382 }
383
384 pub fn generate_jwt(
385 &self,
386 client: &Client,
387 options: crate::testkit::JwtOptions,
388 ) -> Result<String, jsonwebtoken::errors::Error> {
389 let scope = options.scope.unwrap_or_else(|| client.scope.clone());
390 issue_jwt(
391 self.issuer(),
392 &client.client_id,
393 &options.user_id,
394 &scope,
395 options.expires_in,
396 )
397 }
398
399 pub fn router(self) -> Router {
400 let cors = build_cors_layer(&self.config);
401 Router::new()
402 .route(
403 "/.well-known/openid-configuration",
404 get(well_known_openid_configuration),
405 )
406 .route("/.well-known/jwks.json", get(jwks))
407 .route("/register", post(register_client))
408 .route("/register/{client_id}", get(get_client))
409 .route("/authorize", get(authorize))
410 .route("/token", post(token_endpoint))
411 .route("/introspect", post(introspect))
412 .route("/revoke", post(revoke))
413 .route("/userinfo", get(userinfo))
414 .route("/error", get(error_page))
415 .with_state(self)
416 .layer(cors)
417 }
418
419 pub async fn start(mut self) -> (SocketAddr, JoinHandle<()>) {
420 let port = self.config.port;
421 let addr = SocketAddr::from(([127, 0, 0, 1], port));
422 let listener = TcpListener::bind(addr).await.unwrap();
423 let local_addr = listener.local_addr().unwrap();
424 let base_url = format!(
425 "{}://{}:{}",
426 self.config.scheme,
427 self.config.host,
428 local_addr.port()
429 );
430 self.base_url = base_url;
431
432 let router = self.router();
433 let handle = tokio::spawn(async move {
434 axum::serve(listener, router).await.unwrap();
435 });
436 (local_addr, handle)
437 }
438}
439
440fn generate_code() -> String {
441 Uuid::new_v4().to_string()[..20].to_string()
442}
443
444fn generate_token() -> String {
445 format!("tok_{}", Uuid::new_v4().to_string().replace("-", ""))
446}
447
448fn issue_jwt(
449 issuer: &str,
450 client_id: &str,
451 user_id: &str,
452 requested_scope: &str,
453 expires_in: i64,
454) -> Result<String, jsonwebtoken::errors::Error> {
455 let iat = Utc::now().timestamp() as usize;
456 let exp = (Utc::now() + Duration::seconds(expires_in)).timestamp() as usize;
457
458 let scopes: Vec<&str> = requested_scope.split_whitespace().collect();
460
461 let claims = Claims {
463 iss: issuer.to_string(),
464 sub: user_id.to_string(),
465 aud: client_id.to_string(),
466 exp,
467 iat,
468 scope: Some(scopes.join(" ")), auth_time: Some(iat),
470 typ: "Bearer".to_string(), azp: Some(client_id.to_string()),
473 sid: Some(format!("sid-{}", Uuid::new_v4())),
474 jti: Uuid::new_v4().to_string(),
475 };
476
477 let mut header = Header::new(Algorithm::RS256);
479 header.typ = Some("JWT".to_string());
480 header.kid = Some(KID.to_string());
481
482 encode(&header, &claims, &KEYS.encoding)
484}
485
486pub static KEYS: Lazy<Keys> = Lazy::new(|| {
488 let mut rng = rand::thread_rng();
489 let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate key");
490 let public_key = private_key.to_public_key();
491
492 let private_pem = private_key
494 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
495 .expect("failed to encode private key")
496 .to_string();
497
498 let public_pem = public_key
499 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
500 .expect("failed to encode public key")
501 .to_string();
502
503 let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap();
504 let decoding_key = DecodingKey::from_rsa_pem(public_pem.as_bytes()).unwrap();
505
506 Keys {
507 encoding: encoding_key,
508 decoding: decoding_key,
509 public_pem, }
511});
512
513static JWKS_JSON: Lazy<serde_json::Value> = Lazy::new(|| {
514 let public_key = rsa::RsaPublicKey::from_public_key_pem(&KEYS.public_pem)
515 .expect("Failed to parse stored public key");
516
517 let jwk = Jwk {
518 common: CommonParameters {
519 key_algorithm: Some(jsonwebtoken::jwk::KeyAlgorithm::RS256),
520 key_id: Some(KID.to_string()),
521 ..Default::default()
522 },
523 algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA(
524 jsonwebtoken::jwk::RSAKeyParameters {
525 n: general_purpose::STANDARD.encode(public_key.n().to_bytes_be()),
526 e: general_purpose::STANDARD.encode(public_key.e().to_bytes_be()),
527 key_type: jsonwebtoken::jwk::RSAKeyType::RSA,
528 },
529 ),
530 };
531
532 json!({ "keys": [jwk] })
533});
534
/// RSA key material used to sign (and optionally verify) the issuer's JWTs.
// NOTE(review): `decoding` is never read in this file; the `allow(unused)` presumably
// silences the resulting warning. It may be consumed by tests elsewhere — confirm
// before removing either.
#[allow(unused)]
pub struct Keys {
    /// RS256 signing key, built from the PKCS#8 private-key PEM.
    pub encoding: EncodingKey,
    /// RS256 verification key, built from the public-key PEM.
    pub decoding: DecodingKey,
    /// PEM-encoded public key; the published JWKS is derived from it.
    pub public_pem: String,
}
541
/// Typed shape of a token-endpoint success response.
///
/// NOTE(review): not constructed or read anywhere in this file; the token endpoint
/// builds its JSON with `json!`. Confirm it is used elsewhere, or remove it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AccessToken {
    access_token: String,
    token_type: String,
    expires_in: i64,
    scope: Option<String>,
    refresh_token: Option<String>,
    id_token: Option<String>,
}
551
/// JWT claim set of the access tokens produced by `issue_jwt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Claims {
    /// Issuer URL.
    iss: String,
    /// Subject (user id, or `"client"` for client-credentials tokens).
    sub: String,
    /// Audience: the client id.
    aud: String,
    /// Expiry, Unix seconds.
    exp: usize,
    /// Issued-at, Unix seconds.
    iat: usize,
    /// Space-separated granted scopes.
    scope: Option<String>,
    /// Authentication time; set equal to `iat`.
    auth_time: Option<usize>,
    /// Token type; always `"Bearer"`.
    typ: String,
    /// Authorized party: the client id.
    azp: Option<String>,
    /// Random per-token session id (`sid-<uuid>`).
    sid: Option<String>,
    /// Random per-token JWT id.
    jti: String,
}
566
567async fn well_known_openid_configuration(State(state): State<AppState>) -> impl IntoResponse {
569 let discovery = state.config.to_discovery_document(state.base_url);
570 (StatusCode::OK, Json(discovery))
571}
572
573async fn jwks() -> impl IntoResponse {
574 (StatusCode::OK, Json(JWKS_JSON.clone()))
575}
576
577async fn register_client(
578 State(state): State<AppState>,
579 Json(metadata): Json<serde_json::Value>,
580) -> impl IntoResponse {
581 let requested_scope = metadata
582 .get("scope")
583 .and_then(|v| v.as_str())
584 .unwrap_or("openid");
585
586 match state.config.validate_scope(requested_scope) {
587 Ok(_) => { }
588 Err(e) => {
589 return (
590 StatusCode::BAD_REQUEST,
591 Json(json!({ "error": "invalid_scope", "error_description": e })),
592 );
593 }
594 };
595
596 let client_id = Uuid::new_v4().to_string();
597
598 let client_secret = if state.config.generate_client_secret_for_dcr
599 || metadata
600 .get("token_endpoint_auth_method")
601 .and_then(|v| v.as_str())
602 != Some("none")
603 {
604 Some(generate_token())
605 } else {
606 None
607 };
608
609 let redirect_uris = metadata
610 .get("redirect_uris")
611 .and_then(|v| v.as_array())
612 .map(|arr| {
613 arr.iter()
614 .filter_map(|u| u.as_str().map(|s| s.to_string()))
615 .collect::<Vec<String>>()
616 })
617 .unwrap_or_default();
618
619 if redirect_uris.is_empty()
620 && metadata.get("grant_types").map(|v| {
621 v.as_array()
622 .map(|a| a.contains(&json!("client_credentials")))
623 }) != Some(Some(true))
624 {
625 return (
626 StatusCode::BAD_REQUEST,
627 Json(json!({"error": "redirect_uris required"})),
628 );
629 }
630
631 let client = Client {
632 client_id: client_id.clone(),
633 client_secret: client_secret.clone(),
634 redirect_uris,
635 grant_types: metadata
636 .get("grant_types")
637 .and_then(|v| v.as_array())
638 .map(|arr| {
639 arr.iter()
640 .filter_map(|v| v.as_str().map(|s| s.to_string()))
641 .collect()
642 })
643 .unwrap_or_else(|| vec!["authorization_code".to_string()]),
644 response_types: metadata
645 .get("response_types")
646 .and_then(|v| v.as_array())
647 .map(|arr| {
648 arr.iter()
649 .filter_map(|v| v.as_str().map(|s| s.to_string()))
650 .collect()
651 })
652 .unwrap_or_else(|| vec!["code".to_string()]),
653 scope: metadata
654 .get("scope")
655 .and_then(|v| v.as_str())
656 .unwrap_or("")
657 .to_string(),
658 token_endpoint_auth_method: metadata
659 .get("token_endpoint_auth_method")
660 .and_then(|v| v.as_str())
661 .unwrap_or("client_secret_basic")
662 .to_string(),
663 client_name: metadata
664 .get("client_name")
665 .and_then(|v| v.as_str())
666 .map(|s| s.to_string()),
667 client_uri: metadata
668 .get("client_uri")
669 .and_then(|v| v.as_str())
670 .map(|s| s.to_string()),
671 logo_uri: metadata
672 .get("logo_uri")
673 .and_then(|v| v.as_str())
674 .map(|s| s.to_string()),
675 contacts: metadata
676 .get("contacts")
677 .and_then(|v| v.as_array())
678 .map(|arr| {
679 arr.iter()
680 .filter_map(|v| v.as_str().map(|s| s.to_string()))
681 .collect()
682 })
683 .unwrap_or_default(),
684 policy_uri: metadata
685 .get("policy_uri")
686 .and_then(|v| v.as_str())
687 .map(|s| s.to_string()),
688 tos_uri: metadata
689 .get("tos_uri")
690 .and_then(|v| v.as_str())
691 .map(|s| s.to_string()),
692 jwks: metadata.get("jwks").cloned(),
693 jwks_uri: metadata
694 .get("jwks_uri")
695 .and_then(|v| v.as_str())
696 .map(|s| s.to_string()),
697 software_id: metadata
698 .get("software_id")
699 .and_then(|v| v.as_str())
700 .map(|s| s.to_string()),
701 software_version: metadata
702 .get("software_version")
703 .and_then(|v| v.as_str())
704 .map(|s| s.to_string()),
705 registration_access_token: None,
706 registration_client_uri: Some(format!("{}/register/{}", state.issuer(), client_id)),
707 };
708
709 state
710 .clients
711 .write()
712 .unwrap()
713 .insert(client_id.clone(), client.clone());
714
715 let response = json!({
716 "client_id": client.client_id,
717 "client_secret": client.client_secret,
718 "client_id_issued_at": Utc::now().timestamp(),
719 "registration_client_uri": client.registration_client_uri,
720 "registration_access_token": Uuid::new_v4().to_string(),
721 "redirect_uris": client.redirect_uris,
722 "grant_types": client.grant_types,
723 "response_types": client.response_types,
724 "scope": client.scope,
725 "token_endpoint_auth_method": client.token_endpoint_auth_method
726 });
727
728 (StatusCode::CREATED, Json(response))
729}
730
731async fn get_client(
732 State(state): State<AppState>,
733 Path(client_id): Path<String>,
734) -> impl IntoResponse {
735 if let Some(client) = state.clients.read().unwrap().get(&client_id) {
736 let response = json!({
737 "client_id": client.client_id,
738 "client_name": client.client_name,
739 "redirect_uris": client.redirect_uris,
740 "grant_types": client.grant_types,
741 "scope": client.scope
742 });
743 (StatusCode::OK, Json(response))
744 } else {
745 (
746 StatusCode::NOT_FOUND,
747 Json(json!({"error": "client not found"})),
748 )
749 }
750}
751
/// Query parameters accepted by `GET /authorize`.
#[derive(Deserialize)]
struct AuthorizeQuery {
    /// Must be `"code"`; anything else is rejected.
    response_type: String,
    /// Registered client id.
    client_id: String,
    /// Must be one of the client's registered redirect URIs; defaults to the first one.
    redirect_uri: Option<String>,
    /// Space-separated requested scopes; intersected with the client's registered scope.
    scope: Option<String>,
    /// Opaque client value echoed back on the redirect.
    state: Option<String>,
    /// PKCE code challenge, stored with the code.
    code_challenge: Option<String>,
    /// PKCE method (`plain`/`S256`), stored with the code.
    code_challenge_method: Option<String>,
}
762
763async fn authorize(
764 State(state): State<AppState>,
765 Query(params): Query<AuthorizeQuery>,
766) -> impl IntoResponse {
767 let clients = state.clients.read().unwrap();
768 let client = match clients.get(¶ms.client_id) {
769 Some(c) => c,
770 None => {
771 return Redirect::to(&format!(
772 "/error?error=invalid_client&state={}",
773 params.state.as_deref().unwrap_or("")
774 ))
775 .into_response();
776 }
777 };
778
779 if params.response_type != "code" {
780 return Redirect::to(&format!(
781 "/error?error=unsupported_response_type&state={}",
782 params.state.as_deref().unwrap_or("")
783 ))
784 .into_response();
785 }
786
787 let redirect_uri = match ¶ms.redirect_uri {
788 Some(uri) => {
789 if !client.redirect_uris.contains(uri) {
790 return Redirect::to(&format!(
791 "/error?error=invalid_request&state={}",
792 params.state.as_deref().unwrap_or("")
793 ))
794 .into_response();
795 }
796 uri.clone()
797 }
798 None => client.redirect_uris.first().unwrap().clone(),
799 };
800
801 let code = generate_code();
802
803 let requested_scopes: HashSet<String> = params
805 .scope
806 .clone()
807 .unwrap_or_default()
808 .split_whitespace()
809 .map(|s| s.to_string())
810 .collect();
811
812 let registered_scopes: HashSet<String> = client
813 .scope
814 .split_whitespace()
815 .map(|s| s.to_string())
816 .collect();
817
818 let granted_scopes: Vec<String> = requested_scopes
819 .intersection(®istered_scopes)
820 .cloned()
821 .collect();
822
823 let final_scope = granted_scopes.join(" ");
824
825 let auth_code = AuthorizationCode {
826 code: code.clone(),
827 client_id: params.client_id.clone(),
828 redirect_uri: redirect_uri.clone(),
829 scope: final_scope, expires_at: Utc::now() + Duration::minutes(10),
831 code_challenge: params.code_challenge.clone(),
832 code_challenge_method: params.code_challenge_method.clone(),
833 user_id: "test-user-123".to_string(),
834 };
835
836 state.codes.write().unwrap().insert(code.clone(), auth_code);
837
838 let redirect_url = format!(
839 "{}?code={}&state={}",
840 redirect_uri,
841 code,
842 params.state.as_deref().unwrap_or("")
843 );
844
845 Redirect::to(&redirect_url).into_response()
846}
847
848#[derive(Deserialize)]
849struct TokenRequest {
850 grant_type: String,
851 code: Option<String>,
852 _redirect_uri: Option<String>,
853 client_id: Option<String>,
854 _client_secret: Option<String>,
855 refresh_token: Option<String>,
856 code_verifier: Option<String>,
857 scope: Option<String>,
858}
859
860async fn token_endpoint(
861 State(state): State<AppState>,
862 _headers: HeaderMap,
863 Form(form): Form<TokenRequest>,
864) -> impl IntoResponse {
865 if form.grant_type == "authorization_code" {
866 let code = form.code.as_deref().unwrap_or("");
867 let code_obj = match state.codes.write().unwrap().remove(code) {
868 Some(c) => c,
869 None => {
870 return (
871 StatusCode::BAD_REQUEST,
872 Json(json!({"error": "invalid_grant"})),
873 )
874 .into_response();
875 }
876 };
877
878 if code_obj.expires_at < Utc::now() {
879 return (
880 StatusCode::BAD_REQUEST,
881 Json(json!({"error": "invalid_grant"})),
882 )
883 .into_response();
884 }
885
886 if let (Some(challenge), Some(verifier)) = (&code_obj.code_challenge, &form.code_verifier) {
887 let method = code_obj.code_challenge_method.as_deref().unwrap_or("plain");
888 let computed = if method == "S256" {
889 general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(verifier.as_bytes()))
890 } else {
891 verifier.clone()
892 };
893 if computed != *challenge {
894 return (
895 StatusCode::BAD_REQUEST,
896 Json(json!({"error": "invalid_grant"})),
897 )
898 .into_response();
899 }
900 }
901
902 let refresh_token = generate_token();
904
905 let jwt = issue_jwt(
906 state.issuer(),
907 &code_obj.client_id,
908 &code_obj.user_id,
909 &code_obj.scope,
910 3600,
911 )
912 .unwrap();
913
914 let token = Token {
915 access_token: jwt.clone(),
916 refresh_token: Some(refresh_token.clone()),
917 client_id: code_obj.client_id.clone(),
918 scope: code_obj.scope.clone(),
919 expires_at: Utc::now() + Duration::hours(1),
920 user_id: code_obj.user_id.clone(),
921 revoked: false,
922 };
923
924 state
925 .tokens
926 .write()
927 .unwrap()
928 .insert(jwt.clone(), token.clone());
929 state
930 .refresh_tokens
931 .write()
932 .unwrap()
933 .insert(refresh_token.clone(), token);
934
935 let response = json!({
936 "access_token": jwt,
937 "token_type": "Bearer",
938 "expires_in": 3600,
939 "refresh_token": refresh_token,
940 "scope": code_obj.scope
941 });
942
943 (StatusCode::OK, Json(response)).into_response()
944 } else if form.grant_type == "refresh_token" {
945 let rt = form.refresh_token.as_deref().unwrap_or("");
946 let mut guard = state.refresh_tokens.write().unwrap();
947 let token = match guard.get_mut(rt) {
948 Some(t) => t,
949 None => {
950 return (
951 StatusCode::BAD_REQUEST,
952 Json(json!({"error": "invalid_grant"})),
953 )
954 .into_response();
955 }
956 };
957
958 if token.revoked {
959 return (
960 StatusCode::BAD_REQUEST,
961 Json(json!({"error": "invalid_grant"})),
962 )
963 .into_response();
964 }
965
966 let new_access_token = issue_jwt(
967 state.issuer(),
968 &token.client_id,
969 &token.user_id,
970 &token.scope,
971 3600,
972 )
973 .unwrap();
974 let new_refresh_token = generate_token();
975
976 let new_token = Token {
977 access_token: new_access_token.clone(),
978 refresh_token: Some(new_refresh_token.clone()),
979 client_id: token.client_id.clone(),
980 scope: token.scope.clone(),
981 expires_at: Utc::now() + Duration::hours(1),
982 user_id: token.user_id.clone(),
983 revoked: false,
984 };
985
986 state
987 .tokens
988 .write()
989 .unwrap()
990 .insert(new_access_token.clone(), new_token.clone());
991 state
992 .refresh_tokens
993 .write()
994 .unwrap()
995 .insert(new_refresh_token.clone(), new_token);
996
997 token.revoked = true;
998
999 let response = json!({
1000 "access_token": new_access_token,
1001 "token_type": "Bearer",
1002 "expires_in": 3600,
1003 "refresh_token": new_refresh_token,
1004 "scope": token.scope
1005 });
1006
1007 (StatusCode::OK, Json(response)).into_response()
1008 } else if form.grant_type == "client_credentials" {
1009 let client_id = form.client_id.as_deref().unwrap_or("");
1010 let clients = state.clients.read().unwrap();
1011 let client = match clients.get(client_id) {
1012 Some(c) => c,
1013 None => {
1014 return (
1015 StatusCode::BAD_REQUEST,
1016 Json(json!({"error": "invalid_client"})),
1017 )
1018 .into_response();
1019 }
1020 };
1021
1022 let requested_scopes: HashSet<String> = form
1024 .scope
1025 .as_deref()
1026 .unwrap_or("")
1027 .split_whitespace()
1028 .map(|s| s.to_string())
1029 .collect();
1030
1031 if let Some(requested_scope) = form.scope.as_deref() {
1032 if let Err(e) = state.config.validate_scope(requested_scope) {
1034 return (
1035 StatusCode::BAD_REQUEST,
1036 Json(json!({
1037 "error": "invalid_scope",
1038 "error_description": e
1039 })),
1040 )
1041 .into_response();
1042 }
1043
1044 let client_scopes: HashSet<_> = client.scope.split_whitespace().collect();
1046 let requested_scopes: HashSet<_> = requested_scope.split_whitespace().collect();
1047
1048 let not_permitted: Vec<_> = requested_scopes
1049 .difference(&client_scopes)
1050 .cloned()
1051 .collect();
1052 if !not_permitted.is_empty() {
1053 return (StatusCode::BAD_REQUEST, Json(json!({
1054 "error": "invalid_scope",
1055 "error_description": format!("Client not authorized for scopes: {}", not_permitted.join(" "))
1056 }))).into_response();
1057 }
1058 }
1059
1060 let registered_scopes: HashSet<String> = client
1062 .scope
1063 .split_whitespace()
1064 .map(|s| s.to_string())
1065 .collect();
1066
1067 let granted_scopes: Vec<String> = requested_scopes
1069 .intersection(®istered_scopes)
1070 .cloned()
1071 .collect();
1072
1073 if granted_scopes.is_empty() && !requested_scopes.is_empty() {
1075 return (
1076 StatusCode::BAD_REQUEST,
1077 Json(json!({
1078 "error": "invalid_scope",
1079 "error_description": "Requested scopes not allowed for this client"
1080 })),
1081 )
1082 .into_response();
1083 }
1084
1085 let final_scope = if requested_scopes.is_empty() {
1087 client.scope.clone()
1089 } else {
1090 granted_scopes.join(" ")
1091 };
1092
1093 let access_token =
1095 issue_jwt(state.issuer(), client_id, "client", &final_scope, 3600).unwrap();
1096
1097 let response = json!({
1098 "access_token": access_token,
1099 "token_type": "Bearer",
1100 "expires_in": 3600,
1101 "scope": final_scope
1102 });
1103
1104 (StatusCode::OK, Json(response)).into_response()
1105 } else {
1106 (
1107 StatusCode::BAD_REQUEST,
1108 Json(json!({"error": "unsupported_grant_type"})),
1109 )
1110 .into_response()
1111 }
1112}
1113
1114async fn introspect(
1115 State(state): State<AppState>,
1116 Form(form): Form<HashMap<String, String>>,
1117) -> impl IntoResponse {
1118 let token = form.get("token").cloned().unwrap_or_default();
1119 let token_data = state.tokens.read().unwrap().get(&token).cloned();
1120
1121 if let Some(t) = token_data {
1122 let active = !t.revoked && t.expires_at > Utc::now();
1123 let response = json!({
1124 "active": active,
1125 "client_id": t.client_id,
1126 "scope": t.scope,
1127 "sub": t.user_id,
1128 "exp": t.expires_at.timestamp(),
1129 "iat": (t.expires_at - Duration::hours(1)).timestamp(),
1130 "token_type": "Bearer"
1131 });
1132 (StatusCode::OK, Json(response)).into_response()
1133 } else {
1134 (StatusCode::OK, Json(json!({"active": false}))).into_response()
1135 }
1136}
1137
1138async fn revoke(
1139 State(state): State<AppState>,
1140 Form(form): Form<HashMap<String, String>>,
1141) -> impl IntoResponse {
1142 let token = form.get("token").cloned().unwrap_or_default();
1143 if let Some(t) = state.tokens.write().unwrap().get_mut(&token) {
1144 t.revoked = true;
1145 }
1146 if let Some(t) = state.refresh_tokens.write().unwrap().get_mut(&token) {
1147 t.revoked = true;
1148 }
1149 (StatusCode::OK, Json(json!({}))).into_response()
1150}
1151
1152async fn userinfo(headers: HeaderMap, State(state): State<AppState>) -> impl IntoResponse {
1153 let auth = headers.get("Authorization").and_then(|v| v.to_str().ok());
1154 if let Some(auth) = auth {
1155 if let Some(token) = auth.strip_prefix("Bearer ") {
1156 if let Some(t) = state.tokens.read().unwrap().get(token) {
1157 if t.revoked || t.expires_at < Utc::now() {
1158 return (
1159 StatusCode::UNAUTHORIZED,
1160 Json(json!({"error": "invalid_token"})),
1161 )
1162 .into_response();
1163 }
1164 let response = json!({
1165 "sub": t.user_id,
1166 "name": "Test User",
1167 "email": "test@example.com",
1168 "picture": "https://example.com/avatar.jpg"
1169 });
1170 return (StatusCode::OK, Json(response)).into_response();
1171 }
1172 }
1173 }
1174 (
1175 StatusCode::UNAUTHORIZED,
1176 Json(json!({"error": "invalid_token"})),
1177 )
1178 .into_response()
1179}
1180
1181async fn error_page(Query(params): Query<HashMap<String, String>>) -> Html<String> {
1182 let error = params.get("error").map(|s| s.as_str()).unwrap_or("unknown");
1183 Html(format!("<h1>OAuth Error: {}</h1>", error))
1184}
1185
1186fn build_cors_layer(config: &IssuerConfig) -> CorsLayer {
1187 let allowed_origins: Vec<HeaderValue> = config
1188 .allowed_origins
1189 .iter()
1190 .filter_map(|s| s.parse().ok())
1191 .collect();
1192
1193 let allowed = if allowed_origins.is_empty() {
1194 AllowOrigin::any()
1195 } else {
1196 AllowOrigin::list(allowed_origins)
1197 };
1198
1199 let allowed_methods =
1200 AllowMethods::list([http::Method::GET, http::Method::POST, http::Method::OPTIONS]);
1201
1202 let allowed_headers = AllowHeaders::list([
1203 header::AUTHORIZATION,
1204 header::CONTENT_TYPE,
1205 header::ACCEPT,
1206 "x-requested-with".parse().unwrap(),
1207 ]);
1208
1209 CorsLayer::new()
1210 .allow_origin(allowed)
1211 .allow_methods(allowed_methods)
1212 .allow_headers(allowed_headers)
1213 .allow_credentials(true)
1214 .max_age(std::time::Duration::from_secs(86400))
1215}