1use axum::{
2 extract::{Form, Json, Path, Query, State},
3 http::{header, HeaderMap, StatusCode},
4 response::{Html, IntoResponse, Redirect},
5 routing::{get, post},
6 Router,
7};
8use base64::{engine::general_purpose, Engine};
9use chrono::{Duration, Utc};
10use http::HeaderValue;
11use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header};
12use jsonwebtoken::{
13 jwk::{CommonParameters, Jwk},
14 Validation,
15};
16use once_cell::sync::Lazy;
17use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
18use rsa::traits::PublicKeyParts;
19use rsa::RsaPrivateKey;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22use sha2::Digest;
23use std::{
24 collections::{HashMap, HashSet},
25 net::SocketAddr,
26 sync::{Arc, RwLock},
27};
28use tokio::{net::TcpListener, task::JoinHandle};
29use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
30use uuid::Uuid;
31
/// Key identifier (`kid`) written into every issued JWT header and published
/// alongside the public key in the JWKS document, so verifiers can match a
/// token to its signing key.
pub const KID: &str = "oauth-by-rustmcp";
33
34use serde_json::Value;
36
/// Static configuration of the mock OAuth 2.0 / OpenID Connect issuer.
///
/// The `*_supported` collections are what the discovery document advertises;
/// `scopes_supported` and `grant_types_supported` are additionally used for
/// validation.
#[derive(Debug, Clone)]
pub struct IssuerConfig {
    /// URL scheme used to build the issuer URL (e.g. `"http"`).
    pub scheme: String,
    /// Host name used to build the issuer URL.
    pub host: String,
    /// Port to bind; `0` asks the OS for a free port, which `AppState::start`
    /// then reads back from the bound listener.
    pub port: u16,

    /// Scopes clients may request.
    pub scopes_supported: HashSet<String>,
    /// Claims advertised in discovery.
    pub claims_supported: Vec<String>,
    /// Grant types advertised in discovery and checked by `validate_grant_type`.
    pub grant_types_supported: HashSet<String>,
    /// Response types advertised in discovery.
    pub response_types_supported: HashSet<String>,
    /// Client authentication methods advertised in discovery.
    pub token_endpoint_auth_methods_supported: HashSet<String>,
    /// PKCE challenge methods advertised in discovery.
    pub code_challenge_methods_supported: HashSet<String>,
    /// Subject identifier types advertised in discovery.
    pub subject_types_supported: Vec<String>,
    /// ID-token signing algorithms advertised in discovery.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// When `true`, dynamic registration issues a client secret even to
    /// clients registering with `token_endpoint_auth_method = "none"`.
    pub generate_client_secret_for_dcr: bool,
    /// Origins allowed by the CORS layer (entries that fail to parse as a
    /// header value are skipped).
    pub allowed_origins: Vec<String>,
}
55
56impl Default for IssuerConfig {
57 fn default() -> Self {
58 let mut scopes = HashSet::new();
59 scopes.extend([
60 "openid".into(),
61 "profile".into(),
62 "email".into(),
63 "offline_access".into(),
64 "address".into(),
65 "phone".into(),
66 ]);
67
68 let mut grants = HashSet::new();
69 grants.extend([
70 "authorization_code".into(),
71 "refresh_token".into(),
72 "client_credentials".into(),
73 ]);
74
75 let mut auth_methods = HashSet::new();
76 auth_methods.extend([
77 "client_secret_basic".into(),
78 "client_secret_post".into(),
79 "none".into(),
80 "private_key_jwt".into(),
81 ]);
82
83 Self {
84 scheme: "http".into(),
85 host: "localhost".into(),
86 port: 0, scopes_supported: scopes,
88 claims_supported: vec![
89 "sub".into(),
90 "name".into(),
91 "given_name".into(),
92 "family_name".into(),
93 "email".into(),
94 "email_verified".into(),
95 "picture".into(),
96 "locale".into(),
97 ],
98 generate_client_secret_for_dcr: true,
99 grant_types_supported: grants,
100 response_types_supported: ["code".into(), "token".into(), "id_token".into()].into(),
101 token_endpoint_auth_methods_supported: auth_methods,
102 code_challenge_methods_supported: ["plain".into(), "S256".into()].into(),
103 subject_types_supported: vec!["public".into()],
104 id_token_signing_alg_values_supported: vec!["RS256".into()],
105 allowed_origins: vec![
106 "http://localhost:3001".to_string(),
107 "http://localhost:8080".to_string(),
108 "http://localhost:6274".to_string(),
109 ],
110 }
111 }
112}
113
114impl IssuerConfig {
115 pub fn to_discovery_document(&self, issuer: String) -> Value {
116 let iss = issuer;
117 json!({
118 "issuer": iss,
119 "authorization_endpoint": format!("{}/authorize", iss),
120 "token_endpoint": format!("{}/token", iss),
121 "userinfo_endpoint": format!("{}/userinfo", iss),
122 "jwks_uri": format!("{}/.well-known/jwks.json", iss),
123 "registration_endpoint": format!("{}/register", iss),
124 "revocation_endpoint": format!("{}/revoke", iss),
125 "introspection_endpoint": format!("{}/introspect", iss),
126 "scopes_supported": self.scopes_supported.iter().collect::<Vec<_>>(),
127 "claims_supported": &self.claims_supported,
128 "grant_types_supported": self.grant_types_supported.iter().collect::<Vec<_>>(),
129 "response_types_supported": self.response_types_supported.iter().collect::<Vec<_>>(),
130 "token_endpoint_auth_methods_supported": self.token_endpoint_auth_methods_supported.iter().collect::<Vec<_>>(),
131 "code_challenge_methods_supported": self.code_challenge_methods_supported.iter().collect::<Vec<_>>(),
132 "subject_types_supported": &self.subject_types_supported,
133 "id_token_signing_alg_values_supported": &self.id_token_signing_alg_values_supported,
134 })
135 }
136
137 pub fn validate_scope(&self, scope: &str) -> Result<String, String> {
138 let requested: HashSet<_> = scope.split_whitespace().map(|s| s.to_string()).collect();
139 let unknown: Vec<_> = requested
140 .difference(&self.scopes_supported)
141 .cloned()
142 .collect();
143 if unknown.is_empty() {
144 Ok(scope.to_string())
145 } else {
146 Err(format!("invalid_scope: {}", unknown.join(" ")))
147 }
148 }
149
150 pub fn validate_grant_type(&self, grant: &str) -> bool {
151 self.grant_types_supported.contains(grant)
152 }
153}
154
155#[derive(Clone)]
181pub struct AppState {
182 pub config: Arc<IssuerConfig>,
183 pub base_url: String,
184 pub clients: Arc<RwLock<HashMap<String, Client>>>,
185 pub codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
186 pub tokens: Arc<RwLock<HashMap<String, Token>>>,
187 pub refresh_tokens: Arc<RwLock<HashMap<String, Token>>>,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct Client {
192 pub client_id: String,
193 pub client_secret: Option<String>,
194 pub redirect_uris: Vec<String>,
195 pub grant_types: Vec<String>,
196 pub response_types: Vec<String>,
197 pub scope: String,
198 pub token_endpoint_auth_method: String,
199 pub client_name: Option<String>,
200 pub client_uri: Option<String>,
201 pub logo_uri: Option<String>,
202 pub contacts: Vec<String>,
203 pub policy_uri: Option<String>,
204 pub tos_uri: Option<String>,
205 pub jwks: Option<serde_json::Value>,
206 pub jwks_uri: Option<String>,
207 pub software_id: Option<String>,
208 pub software_version: Option<String>,
209 pub registration_access_token: Option<String>,
210 pub registration_client_uri: Option<String>,
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct AuthorizationCode {
215 pub code: String,
216 pub client_id: String,
217 pub redirect_uri: String,
218 pub scope: String,
219 pub expires_at: chrono::DateTime<Utc>,
220 pub code_challenge: Option<String>,
221 pub code_challenge_method: Option<String>,
222 pub user_id: String,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct Token {
227 pub access_token: String,
228 pub refresh_token: Option<String>,
229 pub client_id: String,
230 pub scope: String,
231 pub expires_at: chrono::DateTime<Utc>,
232 pub user_id: String,
233 pub revoked: bool,
234}
235
236impl AppState {
237 pub fn new(config: IssuerConfig) -> Self {
238 let base_url = format!("{}://{}:{}", config.scheme, config.host, config.port);
239 Self {
240 config: Arc::new(config),
241 clients: Arc::new(RwLock::new(HashMap::new())),
242 codes: Arc::new(RwLock::new(HashMap::new())),
243 tokens: Arc::new(RwLock::new(HashMap::new())),
244 refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
245 base_url,
246 }
247 }
248
249 pub fn issuer(&self) -> &str {
250 self.base_url.as_str()
251 }
252
253 pub fn register_client(
254 &self,
255 metadata: serde_json::Value,
256 ) -> Result<Client, (StatusCode, Json<serde_json::Value>)> {
257 let requested_scope = metadata
258 .get("scope")
259 .and_then(|v| v.as_str())
260 .unwrap_or("openid");
261
262 self.config
263 .validate_scope(requested_scope)
264 .map_err(|err| (StatusCode::BAD_REQUEST, Json(json!({"error": err}))))?;
265
266 let client_id = Uuid::new_v4().to_string();
267
268 let client_secret = if self.config.generate_client_secret_for_dcr
269 || metadata
270 .get("token_endpoint_auth_method")
271 .and_then(|v| v.as_str())
272 != Some("none")
273 {
274 Some(generate_token())
275 } else {
276 None
277 };
278
279 let redirect_uris = metadata
280 .get("redirect_uris")
281 .and_then(|v| v.as_array())
282 .map(|arr| {
283 arr.iter()
284 .filter_map(|u| u.as_str().map(|s| s.to_string()))
285 .collect::<Vec<String>>()
286 })
287 .unwrap_or_default();
288
289 if redirect_uris.is_empty()
290 && metadata.get("grant_types").map(|v| {
291 v.as_array()
292 .map(|a| a.contains(&json!("client_credentials")))
293 }) != Some(Some(true))
294 {
295 return Err((
296 StatusCode::BAD_REQUEST,
297 Json(json!({"error": "redirect_uris required"})),
298 ));
299 }
300
301 let client = Client {
302 client_id: client_id.clone(),
303 client_secret: client_secret.clone(),
304 redirect_uris,
305 grant_types: metadata
306 .get("grant_types")
307 .and_then(|v| v.as_array())
308 .map(|arr| {
309 arr.iter()
310 .filter_map(|v| v.as_str().map(|s| s.to_string()))
311 .collect()
312 })
313 .unwrap_or_else(|| vec!["authorization_code".to_string()]),
314 response_types: metadata
315 .get("response_types")
316 .and_then(|v| v.as_array())
317 .map(|arr| {
318 arr.iter()
319 .filter_map(|v| v.as_str().map(|s| s.to_string()))
320 .collect()
321 })
322 .unwrap_or_else(|| vec!["code".to_string()]),
323 scope: metadata
324 .get("scope")
325 .and_then(|v| v.as_str())
326 .unwrap_or("")
327 .to_string(),
328 token_endpoint_auth_method: metadata
329 .get("token_endpoint_auth_method")
330 .and_then(|v| v.as_str())
331 .unwrap_or("client_secret_basic")
332 .to_string(),
333 client_name: metadata
334 .get("client_name")
335 .and_then(|v| v.as_str())
336 .map(|s| s.to_string()),
337 client_uri: metadata
338 .get("client_uri")
339 .and_then(|v| v.as_str())
340 .map(|s| s.to_string()),
341 logo_uri: metadata
342 .get("logo_uri")
343 .and_then(|v| v.as_str())
344 .map(|s| s.to_string()),
345 contacts: metadata
346 .get("contacts")
347 .and_then(|v| v.as_array())
348 .map(|arr| {
349 arr.iter()
350 .filter_map(|v| v.as_str().map(|s| s.to_string()))
351 .collect()
352 })
353 .unwrap_or_default(),
354 policy_uri: metadata
355 .get("policy_uri")
356 .and_then(|v| v.as_str())
357 .map(|s| s.to_string()),
358 tos_uri: metadata
359 .get("tos_uri")
360 .and_then(|v| v.as_str())
361 .map(|s| s.to_string()),
362 jwks: metadata.get("jwks").cloned(),
363 jwks_uri: metadata
364 .get("jwks_uri")
365 .and_then(|v| v.as_str())
366 .map(|s| s.to_string()),
367 software_id: metadata
368 .get("software_id")
369 .and_then(|v| v.as_str())
370 .map(|s| s.to_string()),
371 software_version: metadata
372 .get("software_version")
373 .and_then(|v| v.as_str())
374 .map(|s| s.to_string()),
375 registration_access_token: None,
376 registration_client_uri: Some(format!("{}/register/{}", self.issuer(), client_id)),
377 };
378
379 self.clients
380 .write()
381 .unwrap()
382 .insert(client_id.clone(), client.clone());
383
384 Ok(client)
385 }
386
387 pub fn generate_token(
388 &self,
389 client: &Client,
390 options: crate::testkit::JwtOptions,
391 ) -> Result<Token, jsonwebtoken::errors::Error> {
392 let user_id = options.user_id.clone();
393 let jwt = self.generate_jwt(client, options)?;
394 let refresh_token = generate_token();
395 let token = Token {
396 access_token: jwt.clone(),
397 refresh_token: Some(refresh_token),
398 client_id: client.client_id.clone(),
399 scope: client.scope.clone(),
400 expires_at: Utc::now() + Duration::hours(1),
401 user_id,
402 revoked: false,
403 };
404
405 self.tokens
406 .write()
407 .unwrap()
408 .insert(jwt.clone(), token.clone());
409 Ok(token)
410 }
411 pub fn generate_jwt(
412 &self,
413 client: &Client,
414 options: crate::testkit::JwtOptions,
415 ) -> Result<String, jsonwebtoken::errors::Error> {
416 let scope = options.scope.unwrap_or_else(|| client.scope.clone());
417 issue_jwt(
418 self.issuer(),
419 &client.client_id,
420 &options.user_id,
421 &scope,
422 options.expires_in,
423 )
424 }
425
426 pub fn router(self) -> Router {
427 let cors = build_cors_layer(&self.config);
428 Router::new()
429 .route(
430 "/.well-known/openid-configuration",
431 get(well_known_openid_configuration),
432 )
433 .route("/.well-known/jwks.json", get(jwks))
434 .route("/register", post(register_client))
435 .route("/register/{client_id}", get(get_client))
436 .route("/authorize", get(authorize))
437 .route("/token", post(token_endpoint))
438 .route("/introspect", post(introspect))
439 .route("/revoke", post(revoke))
440 .route("/userinfo", get(userinfo))
441 .route("/error", get(error_page))
442 .with_state(self)
443 .layer(cors)
444 }
445
446 pub async fn start(mut self) -> (SocketAddr, JoinHandle<()>) {
447 let port = self.config.port;
448 let addr = SocketAddr::from(([127, 0, 0, 1], port));
449 let listener = TcpListener::bind(addr).await.unwrap();
450 let local_addr = listener.local_addr().unwrap();
451 let base_url = format!(
452 "{}://{}:{}",
453 self.config.scheme,
454 self.config.host,
455 local_addr.port()
456 );
457 self.base_url = base_url;
458
459 let router = self.router();
460 let handle = tokio::spawn(async move {
461 axum::serve(listener, router).await.unwrap();
462 });
463 (local_addr, handle)
464 }
465}
466
467fn generate_code() -> String {
468 Uuid::new_v4().to_string()[..20].to_string()
469}
470
471fn generate_token() -> String {
472 format!("tok_{}", Uuid::new_v4().to_string().replace("-", ""))
473}
474
475fn issue_jwt(
476 issuer: &str,
477 client_id: &str,
478 user_id: &str,
479 requested_scope: &str,
480 expires_in: i64,
481) -> Result<String, jsonwebtoken::errors::Error> {
482 let iat = Utc::now().timestamp() as usize;
483 let exp = (Utc::now() + Duration::seconds(expires_in)).timestamp() as usize;
484
485 let scopes: Vec<&str> = requested_scope.split_whitespace().collect();
487
488 let claims = Claims {
490 iss: issuer.to_string(),
491 sub: user_id.to_string(),
492 aud: client_id.to_string(),
493 exp,
494 iat,
495 scope: Some(scopes.join(" ")), auth_time: Some(iat),
497 typ: "Bearer".to_string(), azp: Some(client_id.to_string()),
500 sid: Some(format!("sid-{}", Uuid::new_v4())),
501 jti: Uuid::new_v4().to_string(),
502 };
503
504 let mut header = Header::new(Algorithm::RS256);
506 header.typ = Some("JWT".to_string());
507 header.kid = Some(KID.to_string());
508
509 encode(&header, &claims, &KEYS.encoding)
511}
512
513pub static KEYS: Lazy<Keys> = Lazy::new(|| {
515 let mut rng = rand::thread_rng();
516 let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate key");
517 let public_key = private_key.to_public_key();
518
519 let private_pem = private_key
521 .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
522 .expect("failed to encode private key")
523 .to_string();
524
525 let public_pem = public_key
526 .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
527 .expect("failed to encode public key")
528 .to_string();
529
530 let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap();
531 let decoding_key = DecodingKey::from_rsa_pem(public_pem.as_bytes()).unwrap();
532
533 Keys {
534 encoding: encoding_key,
535 decoding: decoding_key,
536 public_pem, }
538});
539
540static JWKS_JSON: Lazy<serde_json::Value> = Lazy::new(|| {
541 let public_key = rsa::RsaPublicKey::from_public_key_pem(&KEYS.public_pem)
542 .expect("Failed to parse stored public key");
543
544 let jwk = Jwk {
545 common: CommonParameters {
546 key_algorithm: Some(jsonwebtoken::jwk::KeyAlgorithm::RS256),
547 key_id: Some(KID.to_string()),
548 ..Default::default()
549 },
550 algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA(
551 jsonwebtoken::jwk::RSAKeyParameters {
552 n: general_purpose::URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be()),
553 e: general_purpose::URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be()),
554 key_type: jsonwebtoken::jwk::RSAKeyType::RSA,
555 },
556 ),
557 };
558
559 json!({ "keys": [jwk] })
560});
561
562#[allow(unused)]
563pub struct Keys {
564 pub encoding: EncodingKey,
565 pub decoding: DecodingKey,
566 pub public_pem: String,
567}
568
569#[derive(Debug, Clone, Serialize, Deserialize)]
570struct AccessToken {
571 access_token: String,
572 token_type: String,
573 expires_in: i64,
574 scope: Option<String>,
575 refresh_token: Option<String>,
576 id_token: Option<String>,
577}
578
579#[derive(Debug, Clone, Serialize, Deserialize)]
580struct Claims {
581 iss: String,
582 sub: String,
583 aud: String,
584 exp: usize,
585 iat: usize,
586 scope: Option<String>,
587 auth_time: Option<usize>,
588 typ: String, azp: Option<String>, sid: Option<String>, jti: String, }
593
594async fn well_known_openid_configuration(State(state): State<AppState>) -> impl IntoResponse {
596 let discovery = state.config.to_discovery_document(state.base_url);
597 (StatusCode::OK, Json(discovery))
598}
599
600async fn jwks() -> impl IntoResponse {
601 (StatusCode::OK, Json(JWKS_JSON.clone()))
602}
603
604async fn register_client(
605 State(state): State<AppState>,
606 Json(metadata): Json<serde_json::Value>,
607) -> impl IntoResponse {
608 let requested_scope = metadata
609 .get("scope")
610 .and_then(|v| v.as_str())
611 .unwrap_or("openid");
612
613 match state.config.validate_scope(requested_scope) {
614 Ok(_) => { }
615 Err(e) => {
616 return (
617 StatusCode::BAD_REQUEST,
618 Json(json!({ "error": "invalid_scope", "error_description": e })),
619 );
620 }
621 };
622
623 let client_id = Uuid::new_v4().to_string();
624
625 let client_secret = if state.config.generate_client_secret_for_dcr
626 || metadata
627 .get("token_endpoint_auth_method")
628 .and_then(|v| v.as_str())
629 != Some("none")
630 {
631 Some(generate_token())
632 } else {
633 None
634 };
635
636 let redirect_uris = metadata
637 .get("redirect_uris")
638 .and_then(|v| v.as_array())
639 .map(|arr| {
640 arr.iter()
641 .filter_map(|u| u.as_str().map(|s| s.to_string()))
642 .collect::<Vec<String>>()
643 })
644 .unwrap_or_default();
645
646 if redirect_uris.is_empty()
647 && metadata.get("grant_types").map(|v| {
648 v.as_array()
649 .map(|a| a.contains(&json!("client_credentials")))
650 }) != Some(Some(true))
651 {
652 return (
653 StatusCode::BAD_REQUEST,
654 Json(json!({"error": "redirect_uris required"})),
655 );
656 }
657
658 let client = Client {
659 client_id: client_id.clone(),
660 client_secret: client_secret.clone(),
661 redirect_uris,
662 grant_types: metadata
663 .get("grant_types")
664 .and_then(|v| v.as_array())
665 .map(|arr| {
666 arr.iter()
667 .filter_map(|v| v.as_str().map(|s| s.to_string()))
668 .collect()
669 })
670 .unwrap_or_else(|| vec!["authorization_code".to_string()]),
671 response_types: metadata
672 .get("response_types")
673 .and_then(|v| v.as_array())
674 .map(|arr| {
675 arr.iter()
676 .filter_map(|v| v.as_str().map(|s| s.to_string()))
677 .collect()
678 })
679 .unwrap_or_else(|| vec!["code".to_string()]),
680 scope: metadata
681 .get("scope")
682 .and_then(|v| v.as_str())
683 .unwrap_or("")
684 .to_string(),
685 token_endpoint_auth_method: metadata
686 .get("token_endpoint_auth_method")
687 .and_then(|v| v.as_str())
688 .unwrap_or("client_secret_basic")
689 .to_string(),
690 client_name: metadata
691 .get("client_name")
692 .and_then(|v| v.as_str())
693 .map(|s| s.to_string()),
694 client_uri: metadata
695 .get("client_uri")
696 .and_then(|v| v.as_str())
697 .map(|s| s.to_string()),
698 logo_uri: metadata
699 .get("logo_uri")
700 .and_then(|v| v.as_str())
701 .map(|s| s.to_string()),
702 contacts: metadata
703 .get("contacts")
704 .and_then(|v| v.as_array())
705 .map(|arr| {
706 arr.iter()
707 .filter_map(|v| v.as_str().map(|s| s.to_string()))
708 .collect()
709 })
710 .unwrap_or_default(),
711 policy_uri: metadata
712 .get("policy_uri")
713 .and_then(|v| v.as_str())
714 .map(|s| s.to_string()),
715 tos_uri: metadata
716 .get("tos_uri")
717 .and_then(|v| v.as_str())
718 .map(|s| s.to_string()),
719 jwks: metadata.get("jwks").cloned(),
720 jwks_uri: metadata
721 .get("jwks_uri")
722 .and_then(|v| v.as_str())
723 .map(|s| s.to_string()),
724 software_id: metadata
725 .get("software_id")
726 .and_then(|v| v.as_str())
727 .map(|s| s.to_string()),
728 software_version: metadata
729 .get("software_version")
730 .and_then(|v| v.as_str())
731 .map(|s| s.to_string()),
732 registration_access_token: None,
733 registration_client_uri: Some(format!("{}/register/{}", state.issuer(), client_id)),
734 };
735
736 state
737 .clients
738 .write()
739 .unwrap()
740 .insert(client_id.clone(), client.clone());
741
742 let response = json!({
743 "client_id": client.client_id,
744 "client_secret": client.client_secret,
745 "client_id_issued_at": Utc::now().timestamp(),
746 "registration_client_uri": client.registration_client_uri,
747 "registration_access_token": Uuid::new_v4().to_string(),
748 "redirect_uris": client.redirect_uris,
749 "grant_types": client.grant_types,
750 "response_types": client.response_types,
751 "scope": client.scope,
752 "token_endpoint_auth_method": client.token_endpoint_auth_method
753 });
754
755 (StatusCode::CREATED, Json(response))
756}
757
758async fn get_client(
759 State(state): State<AppState>,
760 Path(client_id): Path<String>,
761) -> impl IntoResponse {
762 if let Some(client) = state.clients.read().unwrap().get(&client_id) {
763 let response = json!({
764 "client_id": client.client_id,
765 "client_name": client.client_name,
766 "redirect_uris": client.redirect_uris,
767 "grant_types": client.grant_types,
768 "scope": client.scope
769 });
770 (StatusCode::OK, Json(response))
771 } else {
772 (
773 StatusCode::NOT_FOUND,
774 Json(json!({"error": "client not found"})),
775 )
776 }
777}
778
779#[derive(Deserialize)]
780struct AuthorizeQuery {
781 response_type: String,
782 client_id: String,
783 redirect_uri: Option<String>,
784 scope: Option<String>,
785 state: Option<String>,
786 code_challenge: Option<String>,
787 code_challenge_method: Option<String>,
788}
789
790async fn authorize(
791 State(state): State<AppState>,
792 Query(params): Query<AuthorizeQuery>,
793) -> impl IntoResponse {
794 let clients = state.clients.read().unwrap();
795 let client = match clients.get(¶ms.client_id) {
796 Some(c) => c,
797 None => {
798 return Redirect::to(&format!(
799 "/error?error=invalid_client&state={}",
800 params.state.as_deref().unwrap_or("")
801 ))
802 .into_response();
803 }
804 };
805
806 if params.response_type != "code" {
807 return Redirect::to(&format!(
808 "/error?error=unsupported_response_type&state={}",
809 params.state.as_deref().unwrap_or("")
810 ))
811 .into_response();
812 }
813
814 let redirect_uri = match ¶ms.redirect_uri {
815 Some(uri) => {
816 if !client.redirect_uris.contains(uri) {
817 return Redirect::to(&format!(
818 "/error?error=invalid_request&state={}",
819 params.state.as_deref().unwrap_or("")
820 ))
821 .into_response();
822 }
823 uri.clone()
824 }
825 None => client.redirect_uris.first().unwrap().clone(),
826 };
827
828 let code = generate_code();
829
830 let requested_scopes: HashSet<String> = params
832 .scope
833 .clone()
834 .unwrap_or_default()
835 .split_whitespace()
836 .map(|s| s.to_string())
837 .collect();
838
839 let registered_scopes: HashSet<String> = client
840 .scope
841 .split_whitespace()
842 .map(|s| s.to_string())
843 .collect();
844
845 let granted_scopes: Vec<String> = requested_scopes
846 .intersection(®istered_scopes)
847 .cloned()
848 .collect();
849
850 let final_scope = granted_scopes.join(" ");
851
852 let auth_code = AuthorizationCode {
853 code: code.clone(),
854 client_id: params.client_id.clone(),
855 redirect_uri: redirect_uri.clone(),
856 scope: final_scope, expires_at: Utc::now() + Duration::minutes(10),
858 code_challenge: params.code_challenge.clone(),
859 code_challenge_method: params.code_challenge_method.clone(),
860 user_id: "test-user-123".to_string(),
861 };
862
863 state.codes.write().unwrap().insert(code.clone(), auth_code);
864
865 let redirect_url = format!(
866 "{}?code={}&state={}",
867 redirect_uri,
868 code,
869 params.state.as_deref().unwrap_or("")
870 );
871
872 Redirect::to(&redirect_url).into_response()
873}
874
875#[derive(Deserialize)]
876struct TokenRequest {
877 grant_type: String,
878 code: Option<String>,
879 _redirect_uri: Option<String>,
880 client_id: Option<String>,
881 _client_secret: Option<String>,
882 refresh_token: Option<String>,
883 code_verifier: Option<String>,
884 scope: Option<String>,
885}
886
887async fn token_endpoint(
888 State(state): State<AppState>,
889 _headers: HeaderMap,
890 Form(form): Form<TokenRequest>,
891) -> impl IntoResponse {
892 if form.grant_type == "authorization_code" {
893 let code = form.code.as_deref().unwrap_or("");
894 let code_obj = match state.codes.write().unwrap().remove(code) {
895 Some(c) => c,
896 None => {
897 return (
898 StatusCode::BAD_REQUEST,
899 Json(json!({"error": "invalid_grant"})),
900 )
901 .into_response();
902 }
903 };
904
905 if code_obj.expires_at < Utc::now() {
906 return (
907 StatusCode::BAD_REQUEST,
908 Json(json!({"error": "invalid_grant"})),
909 )
910 .into_response();
911 }
912
913 if let (Some(challenge), Some(verifier)) = (&code_obj.code_challenge, &form.code_verifier) {
914 let method = code_obj.code_challenge_method.as_deref().unwrap_or("plain");
915 let computed = if method == "S256" {
916 general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(verifier.as_bytes()))
917 } else {
918 verifier.clone()
919 };
920 if computed != *challenge {
921 return (
922 StatusCode::BAD_REQUEST,
923 Json(json!({"error": "invalid_grant"})),
924 )
925 .into_response();
926 }
927 }
928
929 let refresh_token = generate_token();
931
932 let jwt = issue_jwt(
933 state.issuer(),
934 &code_obj.client_id,
935 &code_obj.user_id,
936 &code_obj.scope,
937 3600,
938 )
939 .unwrap();
940
941 let token = Token {
942 access_token: jwt.clone(),
943 refresh_token: Some(refresh_token.clone()),
944 client_id: code_obj.client_id.clone(),
945 scope: code_obj.scope.clone(),
946 expires_at: Utc::now() + Duration::hours(1),
947 user_id: code_obj.user_id.clone(),
948 revoked: false,
949 };
950
951 state
952 .tokens
953 .write()
954 .unwrap()
955 .insert(jwt.clone(), token.clone());
956 state
957 .refresh_tokens
958 .write()
959 .unwrap()
960 .insert(refresh_token.clone(), token);
961
962 let response = json!({
963 "access_token": jwt,
964 "token_type": "Bearer",
965 "expires_in": 3600,
966 "refresh_token": refresh_token,
967 "scope": code_obj.scope
968 });
969
970 (StatusCode::OK, Json(response)).into_response()
971 } else if form.grant_type == "refresh_token" {
972 let rt = form.refresh_token.as_deref().unwrap_or("");
973 let mut guard = state.refresh_tokens.write().unwrap();
974 let token = match guard.get_mut(rt) {
975 Some(t) => t,
976 None => {
977 return (
978 StatusCode::BAD_REQUEST,
979 Json(json!({"error": "invalid_grant"})),
980 )
981 .into_response();
982 }
983 };
984
985 if token.revoked {
986 return (
987 StatusCode::BAD_REQUEST,
988 Json(json!({"error": "invalid_grant"})),
989 )
990 .into_response();
991 }
992
993 let new_access_token = issue_jwt(
994 state.issuer(),
995 &token.client_id,
996 &token.user_id,
997 &token.scope,
998 3600,
999 )
1000 .unwrap();
1001 let new_refresh_token = generate_token();
1002
1003 let new_token = Token {
1004 access_token: new_access_token.clone(),
1005 refresh_token: Some(new_refresh_token.clone()),
1006 client_id: token.client_id.clone(),
1007 scope: token.scope.clone(),
1008 expires_at: Utc::now() + Duration::hours(1),
1009 user_id: token.user_id.clone(),
1010 revoked: false,
1011 };
1012
1013 state
1014 .tokens
1015 .write()
1016 .unwrap()
1017 .insert(new_access_token.clone(), new_token.clone());
1018 state
1019 .refresh_tokens
1020 .write()
1021 .unwrap()
1022 .insert(new_refresh_token.clone(), new_token);
1023
1024 token.revoked = true;
1025
1026 let response = json!({
1027 "access_token": new_access_token,
1028 "token_type": "Bearer",
1029 "expires_in": 3600,
1030 "refresh_token": new_refresh_token,
1031 "scope": token.scope
1032 });
1033
1034 (StatusCode::OK, Json(response)).into_response()
1035 } else if form.grant_type == "client_credentials" {
1036 let client_id = form.client_id.as_deref().unwrap_or("");
1037 let clients = state.clients.read().unwrap();
1038 let client = match clients.get(client_id) {
1039 Some(c) => c,
1040 None => {
1041 return (
1042 StatusCode::BAD_REQUEST,
1043 Json(json!({"error": "invalid_client"})),
1044 )
1045 .into_response();
1046 }
1047 };
1048
1049 let requested_scopes: HashSet<String> = form
1051 .scope
1052 .as_deref()
1053 .unwrap_or("")
1054 .split_whitespace()
1055 .map(|s| s.to_string())
1056 .collect();
1057
1058 if let Some(requested_scope) = form.scope.as_deref() {
1059 if let Err(e) = state.config.validate_scope(requested_scope) {
1061 return (
1062 StatusCode::BAD_REQUEST,
1063 Json(json!({
1064 "error": "invalid_scope",
1065 "error_description": e
1066 })),
1067 )
1068 .into_response();
1069 }
1070
1071 let client_scopes: HashSet<_> = client.scope.split_whitespace().collect();
1073 let requested_scopes: HashSet<_> = requested_scope.split_whitespace().collect();
1074
1075 let not_permitted: Vec<_> = requested_scopes
1076 .difference(&client_scopes)
1077 .cloned()
1078 .collect();
1079 if !not_permitted.is_empty() {
1080 return (StatusCode::BAD_REQUEST, Json(json!({
1081 "error": "invalid_scope",
1082 "error_description": format!("Client not authorized for scopes: {}", not_permitted.join(" "))
1083 }))).into_response();
1084 }
1085 }
1086
1087 let registered_scopes: HashSet<String> = client
1089 .scope
1090 .split_whitespace()
1091 .map(|s| s.to_string())
1092 .collect();
1093
1094 let granted_scopes: Vec<String> = requested_scopes
1096 .intersection(®istered_scopes)
1097 .cloned()
1098 .collect();
1099
1100 if granted_scopes.is_empty() && !requested_scopes.is_empty() {
1102 return (
1103 StatusCode::BAD_REQUEST,
1104 Json(json!({
1105 "error": "invalid_scope",
1106 "error_description": "Requested scopes not allowed for this client"
1107 })),
1108 )
1109 .into_response();
1110 }
1111
1112 let final_scope = if requested_scopes.is_empty() {
1114 client.scope.clone()
1116 } else {
1117 granted_scopes.join(" ")
1118 };
1119
1120 let access_token =
1122 issue_jwt(state.issuer(), client_id, "client", &final_scope, 3600).unwrap();
1123
1124 let response = json!({
1125 "access_token": access_token,
1126 "token_type": "Bearer",
1127 "expires_in": 3600,
1128 "scope": final_scope
1129 });
1130
1131 (StatusCode::OK, Json(response)).into_response()
1132 } else {
1133 (
1134 StatusCode::BAD_REQUEST,
1135 Json(json!({"error": "unsupported_grant_type"})),
1136 )
1137 .into_response()
1138 }
1139}
1140
1141async fn introspect(
1142 State(state): State<AppState>,
1143 Form(form): Form<HashMap<String, String>>,
1144) -> impl IntoResponse {
1145 let token = match form.get("token") {
1146 Some(t) => t,
1147 None => {
1148 return (
1149 StatusCode::BAD_REQUEST,
1150 Json(json!({"error": "invalid_request"})),
1151 )
1152 .into_response()
1153 }
1154 };
1155
1156 let stored_token = state.tokens.read().unwrap().get(token).cloned();
1158
1159 let mut validation = Validation::new(Algorithm::RS256);
1161 validation.validate_exp = false; validation.required_spec_claims.clear(); validation.validate_aud = false;
1164
1165 match jsonwebtoken::decode::<Claims>(token, &KEYS.decoding, &validation) {
1166 Ok(token_data) => {
1167 let claims = token_data.claims;
1168
1169 let is_expired = Utc::now().timestamp() > claims.exp as i64;
1171 let is_revoked = stored_token.map(|t| t.revoked).unwrap_or(false);
1172 let active = !is_revoked && !is_expired;
1173
1174 let mut response = json!({
1175 "active": active,
1176 "scope": claims.scope,
1177 "client_id": claims.aud,
1178 "sub": claims.sub,
1179 "iss": claims.iss,
1180 "aud": claims.aud,
1181 "iat": claims.iat,
1182 "exp": claims.exp,
1183 "jti": claims.jti,
1184 "token_type": "Bearer",
1185 "azp": claims.azp,
1186 "auth_time": claims.auth_time,
1187 "sid": claims.sid,
1188 });
1189
1190 if claims.scope.is_none() {
1192 response.as_object_mut().unwrap().remove("scope");
1193 }
1194
1195 (StatusCode::OK, Json(response)).into_response()
1196 }
1197 Err(err) => {
1198 println!(">>> err {:?} ", err);
1199
1200 (StatusCode::OK, Json(json!({"active": false}))).into_response()
1202 }
1203 }
1204}
1205
1206async fn revoke(
1207 State(state): State<AppState>,
1208 Form(form): Form<HashMap<String, String>>,
1209) -> impl IntoResponse {
1210 let token = form.get("token").cloned().unwrap_or_default();
1211 if let Some(t) = state.tokens.write().unwrap().get_mut(&token) {
1212 t.revoked = true;
1213 }
1214 if let Some(t) = state.refresh_tokens.write().unwrap().get_mut(&token) {
1215 t.revoked = true;
1216 }
1217 (StatusCode::OK, Json(json!({}))).into_response()
1218}
1219
1220async fn userinfo(headers: HeaderMap, State(state): State<AppState>) -> impl IntoResponse {
1221 let auth = headers.get("Authorization").and_then(|v| v.to_str().ok());
1222 if let Some(auth) = auth {
1223 if let Some(token) = auth.strip_prefix("Bearer ") {
1224 if let Some(t) = state.tokens.read().unwrap().get(token) {
1225 if t.revoked || t.expires_at < Utc::now() {
1226 return (
1227 StatusCode::UNAUTHORIZED,
1228 Json(json!({"error": "invalid_token"})),
1229 )
1230 .into_response();
1231 }
1232 let response = json!({
1233 "sub": t.user_id,
1234 "name": "Test User",
1235 "email": "test@example.com",
1236 "picture": "https://example.com/avatar.jpg"
1237 });
1238 return (StatusCode::OK, Json(response)).into_response();
1239 }
1240 }
1241 }
1242 (
1243 StatusCode::UNAUTHORIZED,
1244 Json(json!({"error": "invalid_token"})),
1245 )
1246 .into_response()
1247}
1248
1249async fn error_page(Query(params): Query<HashMap<String, String>>) -> Html<String> {
1250 let error = params.get("error").map(|s| s.as_str()).unwrap_or("unknown");
1251 Html(format!("<h1>OAuth Error: {}</h1>", error))
1252}
1253
1254fn build_cors_layer(config: &IssuerConfig) -> CorsLayer {
1255 let allowed_origins: Vec<HeaderValue> = config
1256 .allowed_origins
1257 .iter()
1258 .filter_map(|s| s.parse().ok())
1259 .collect();
1260
1261 let allowed = if allowed_origins.is_empty() {
1262 AllowOrigin::any()
1263 } else {
1264 AllowOrigin::list(allowed_origins)
1265 };
1266
1267 let allowed_methods =
1268 AllowMethods::list([http::Method::GET, http::Method::POST, http::Method::OPTIONS]);
1269
1270 let allowed_headers = AllowHeaders::list([
1271 header::AUTHORIZATION,
1272 header::CONTENT_TYPE,
1273 header::ACCEPT,
1274 "x-requested-with".parse().unwrap(),
1275 ]);
1276
1277 CorsLayer::new()
1278 .allow_origin(allowed)
1279 .allow_methods(allowed_methods)
1280 .allow_headers(allowed_headers)
1281 .allow_credentials(true)
1282 .max_age(std::time::Duration::from_secs(86400))
1283}