oauth2_test_server/
server.rs

1use axum::{
2    extract::{Form, Json, Path, Query, State},
3    http::{header, HeaderMap, StatusCode},
4    response::{Html, IntoResponse, Redirect},
5    routing::{get, post},
6    Router,
7};
8use base64::{engine::general_purpose, Engine};
9use chrono::{Duration, Utc};
10use http::HeaderValue;
11use jsonwebtoken::jwk::{CommonParameters, Jwk};
12use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header};
13use once_cell::sync::Lazy;
14use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
15use rsa::traits::PublicKeyParts;
16use rsa::RsaPrivateKey;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use sha2::Digest;
20use std::{
21    collections::{HashMap, HashSet},
22    net::SocketAddr,
23    sync::{Arc, RwLock},
24};
25use tokio::{net::TcpListener, task::JoinHandle};
26use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
27use uuid::Uuid;
28
/// Key ID (`kid`) written into every JWT header and published in the JWKS so
/// verifiers can pick the matching signing key.
pub const KID: &str = "oauth-by-rustmcp";
30
31// src/config.rs
32use serde_json::Value;
33
/// Configuration for the test issuer: where it is reachable and which OAuth 2.0 /
/// OpenID Connect capabilities it advertises in its discovery document.
#[derive(Debug, Clone)]
pub struct IssuerConfig {
    /// URL scheme used to build the issuer/base URL (e.g. `http`).
    pub scheme: String,
    /// Host name used to build the issuer/base URL.
    pub host: String,
    /// Port to bind; `0` lets the OS choose one at start-up.
    pub port: u16,

    // --- OIDC / OAuth capabilities advertised via discovery ---
    /// Scopes clients may request.
    pub scopes_supported: HashSet<String>,
    /// Claims the issuer may return.
    pub claims_supported: Vec<String>,
    /// Grant types accepted at the token endpoint.
    pub grant_types_supported: HashSet<String>,
    /// `response_type` values advertised for the authorization endpoint.
    pub response_types_supported: HashSet<String>,
    /// Client authentication methods advertised for the token endpoint.
    pub token_endpoint_auth_methods_supported: HashSet<String>,
    /// PKCE `code_challenge_method` values advertised.
    pub code_challenge_methods_supported: HashSet<String>,
    /// Subject identifier types.
    pub subject_types_supported: Vec<String>,
    /// JWS algorithms used for ID tokens.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// Controls client-secret issuance during dynamic client registration.
    pub generate_client_secret_for_dcr: bool,
    /// Origins allowed by the CORS layer.
    pub allowed_origins: Vec<String>,
}
52
53impl Default for IssuerConfig {
54    fn default() -> Self {
55        let mut scopes = HashSet::new();
56        scopes.extend([
57            "openid".into(),
58            "profile".into(),
59            "email".into(),
60            "offline_access".into(),
61            "address".into(),
62            "phone".into(),
63        ]);
64
65        let mut grants = HashSet::new();
66        grants.extend([
67            "authorization_code".into(),
68            "refresh_token".into(),
69            "client_credentials".into(),
70        ]);
71
72        let mut auth_methods = HashSet::new();
73        auth_methods.extend([
74            "client_secret_basic".into(),
75            "client_secret_post".into(),
76            "none".into(),
77            "private_key_jwt".into(),
78        ]);
79
80        Self {
81            scheme: "http".into(),
82            host: "localhost".into(),
83            port: 0, // random
84            scopes_supported: scopes,
85            claims_supported: vec![
86                "sub".into(),
87                "name".into(),
88                "given_name".into(),
89                "family_name".into(),
90                "email".into(),
91                "email_verified".into(),
92                "picture".into(),
93                "locale".into(),
94            ],
95            generate_client_secret_for_dcr: true,
96            grant_types_supported: grants,
97            response_types_supported: ["code".into(), "token".into(), "id_token".into()].into(),
98            token_endpoint_auth_methods_supported: auth_methods,
99            code_challenge_methods_supported: ["plain".into(), "S256".into()].into(),
100            subject_types_supported: vec!["public".into()],
101            id_token_signing_alg_values_supported: vec!["RS256".into()],
102            allowed_origins: vec![
103                "http://localhost:3001".to_string(),
104                "http://localhost:8080".to_string(),
105                "http://localhost:6274".to_string(),
106            ],
107        }
108    }
109}
110
111impl IssuerConfig {
112    pub fn to_discovery_document(&self, issuer: String) -> Value {
113        let iss = issuer;
114        json!({
115            "issuer": iss,
116            "authorization_endpoint": format!("{}/authorize", iss),
117            "token_endpoint": format!("{}/token", iss),
118            "userinfo_endpoint": format!("{}/userinfo", iss),
119            "jwks_uri": format!("{}/.well-known/jwks.json", iss),
120            "registration_endpoint": format!("{}/register", iss),
121            "revocation_endpoint": format!("{}/revoke", iss),
122            "introspection_endpoint": format!("{}/introspect", iss),
123            "scopes_supported": self.scopes_supported.iter().collect::<Vec<_>>(),
124            "claims_supported": &self.claims_supported,
125            "grant_types_supported": self.grant_types_supported.iter().collect::<Vec<_>>(),
126            "response_types_supported": self.response_types_supported.iter().collect::<Vec<_>>(),
127            "token_endpoint_auth_methods_supported": self.token_endpoint_auth_methods_supported.iter().collect::<Vec<_>>(),
128            "code_challenge_methods_supported": self.code_challenge_methods_supported.iter().collect::<Vec<_>>(),
129            "subject_types_supported": &self.subject_types_supported,
130            "id_token_signing_alg_values_supported": &self.id_token_signing_alg_values_supported,
131        })
132    }
133
134    pub fn validate_scope(&self, scope: &str) -> Result<String, String> {
135        let requested: HashSet<_> = scope.split_whitespace().map(|s| s.to_string()).collect();
136        let unknown: Vec<_> = requested
137            .difference(&self.scopes_supported)
138            .cloned()
139            .collect();
140        if unknown.is_empty() {
141            Ok(scope.to_string())
142        } else {
143            Err(format!("invalid_scope: {}", unknown.join(" ")))
144        }
145    }
146
147    pub fn validate_grant_type(&self, grant: &str) -> bool {
148        self.grant_types_supported.contains(grant)
149    }
150}
151
152// #[derive(Clone, Debug)]
153// pub struct Config {
154//     pub scheme: String,
155//     pub host: String,
156//     pub port: u16,
157//     pub generate_client_secret_for_dcr: bool,
158//     pub allowed_origins: Vec<String>,
159// }
160
161// impl Default for Config {
162//     fn default() -> Self {
163//         Self {
164//             scheme: "http".to_string(),
165//             host: "localhost".to_string(),
166//             port: 8090,
167//             generate_client_secret_for_dcr: true,
168//             allowed_origins: vec![
169//                 "http://localhost:3001".to_string(),
170//                 "http://localhost:8080".to_string(),
171//                 "http://localhost:6274".to_string(),
172//             ],
173//         }
174//     }
175// }
176
177#[derive(Clone)]
178pub struct AppState {
179    pub config: Arc<IssuerConfig>,
180    pub base_url: String,
181    pub clients: Arc<RwLock<HashMap<String, Client>>>,
182    pub codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
183    pub tokens: Arc<RwLock<HashMap<String, Token>>>,
184    pub refresh_tokens: Arc<RwLock<HashMap<String, Token>>>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct Client {
189    pub client_id: String,
190    pub client_secret: Option<String>,
191    pub redirect_uris: Vec<String>,
192    pub grant_types: Vec<String>,
193    pub response_types: Vec<String>,
194    pub scope: String,
195    pub token_endpoint_auth_method: String,
196    pub client_name: Option<String>,
197    pub client_uri: Option<String>,
198    pub logo_uri: Option<String>,
199    pub contacts: Vec<String>,
200    pub policy_uri: Option<String>,
201    pub tos_uri: Option<String>,
202    pub jwks: Option<serde_json::Value>,
203    pub jwks_uri: Option<String>,
204    pub software_id: Option<String>,
205    pub software_version: Option<String>,
206    pub registration_access_token: Option<String>,
207    pub registration_client_uri: Option<String>,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct AuthorizationCode {
212    pub code: String,
213    pub client_id: String,
214    pub redirect_uri: String,
215    pub scope: String,
216    pub expires_at: chrono::DateTime<Utc>,
217    pub code_challenge: Option<String>,
218    pub code_challenge_method: Option<String>,
219    pub user_id: String,
220}
221
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct Token {
224    access_token: String,
225    refresh_token: Option<String>,
226    client_id: String,
227    scope: String,
228    expires_at: chrono::DateTime<Utc>,
229    user_id: String,
230    revoked: bool,
231}
232
233impl AppState {
234    pub fn new(config: IssuerConfig) -> Self {
235        let base_url = format!("{}://{}:{}", config.scheme, config.host, config.port);
236        Self {
237            config: Arc::new(config),
238            clients: Arc::new(RwLock::new(HashMap::new())),
239            codes: Arc::new(RwLock::new(HashMap::new())),
240            tokens: Arc::new(RwLock::new(HashMap::new())),
241            refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
242            base_url,
243        }
244    }
245
246    pub fn issuer(&self) -> &str {
247        self.base_url.as_str()
248    }
249
250    pub fn register_client(
251        &self,
252        metadata: serde_json::Value,
253    ) -> Result<Client, (StatusCode, Json<serde_json::Value>)> {
254        let requested_scope = metadata
255            .get("scope")
256            .and_then(|v| v.as_str())
257            .unwrap_or("openid");
258
259        self.config
260            .validate_scope(requested_scope)
261            .map_err(|err| (StatusCode::BAD_REQUEST, Json(json!({"error": err}))))?;
262
263        let client_id = Uuid::new_v4().to_string();
264
265        let client_secret = if self.config.generate_client_secret_for_dcr
266            || metadata
267                .get("token_endpoint_auth_method")
268                .and_then(|v| v.as_str())
269                != Some("none")
270        {
271            Some(generate_token())
272        } else {
273            None
274        };
275
276        let redirect_uris = metadata
277            .get("redirect_uris")
278            .and_then(|v| v.as_array())
279            .map(|arr| {
280                arr.iter()
281                    .filter_map(|u| u.as_str().map(|s| s.to_string()))
282                    .collect::<Vec<String>>()
283            })
284            .unwrap_or_default();
285
286        if redirect_uris.is_empty()
287            && metadata.get("grant_types").map(|v| {
288                v.as_array()
289                    .map(|a| a.contains(&json!("client_credentials")))
290            }) != Some(Some(true))
291        {
292            return Err((
293                StatusCode::BAD_REQUEST,
294                Json(json!({"error": "redirect_uris required"})),
295            ));
296        }
297
298        let client = Client {
299            client_id: client_id.clone(),
300            client_secret: client_secret.clone(),
301            redirect_uris,
302            grant_types: metadata
303                .get("grant_types")
304                .and_then(|v| v.as_array())
305                .map(|arr| {
306                    arr.iter()
307                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
308                        .collect()
309                })
310                .unwrap_or_else(|| vec!["authorization_code".to_string()]),
311            response_types: metadata
312                .get("response_types")
313                .and_then(|v| v.as_array())
314                .map(|arr| {
315                    arr.iter()
316                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
317                        .collect()
318                })
319                .unwrap_or_else(|| vec!["code".to_string()]),
320            scope: metadata
321                .get("scope")
322                .and_then(|v| v.as_str())
323                .unwrap_or("")
324                .to_string(),
325            token_endpoint_auth_method: metadata
326                .get("token_endpoint_auth_method")
327                .and_then(|v| v.as_str())
328                .unwrap_or("client_secret_basic")
329                .to_string(),
330            client_name: metadata
331                .get("client_name")
332                .and_then(|v| v.as_str())
333                .map(|s| s.to_string()),
334            client_uri: metadata
335                .get("client_uri")
336                .and_then(|v| v.as_str())
337                .map(|s| s.to_string()),
338            logo_uri: metadata
339                .get("logo_uri")
340                .and_then(|v| v.as_str())
341                .map(|s| s.to_string()),
342            contacts: metadata
343                .get("contacts")
344                .and_then(|v| v.as_array())
345                .map(|arr| {
346                    arr.iter()
347                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
348                        .collect()
349                })
350                .unwrap_or_default(),
351            policy_uri: metadata
352                .get("policy_uri")
353                .and_then(|v| v.as_str())
354                .map(|s| s.to_string()),
355            tos_uri: metadata
356                .get("tos_uri")
357                .and_then(|v| v.as_str())
358                .map(|s| s.to_string()),
359            jwks: metadata.get("jwks").cloned(),
360            jwks_uri: metadata
361                .get("jwks_uri")
362                .and_then(|v| v.as_str())
363                .map(|s| s.to_string()),
364            software_id: metadata
365                .get("software_id")
366                .and_then(|v| v.as_str())
367                .map(|s| s.to_string()),
368            software_version: metadata
369                .get("software_version")
370                .and_then(|v| v.as_str())
371                .map(|s| s.to_string()),
372            registration_access_token: None,
373            registration_client_uri: Some(format!("{}/register/{}", self.issuer(), client_id)),
374        };
375
376        self.clients
377            .write()
378            .unwrap()
379            .insert(client_id.clone(), client.clone());
380
381        Ok(client)
382    }
383
384    pub fn generate_jwt(
385        &self,
386        client: &Client,
387        options: crate::testkit::JwtOptions,
388    ) -> Result<String, jsonwebtoken::errors::Error> {
389        let scope = options.scope.unwrap_or_else(|| client.scope.clone());
390        issue_jwt(
391            self.issuer(),
392            &client.client_id,
393            &options.user_id,
394            &scope,
395            options.expires_in,
396        )
397    }
398
399    pub fn router(self) -> Router {
400        let cors = build_cors_layer(&self.config);
401        Router::new()
402            .route(
403                "/.well-known/openid-configuration",
404                get(well_known_openid_configuration),
405            )
406            .route("/.well-known/jwks.json", get(jwks))
407            .route("/register", post(register_client))
408            .route("/register/{client_id}", get(get_client))
409            .route("/authorize", get(authorize))
410            .route("/token", post(token_endpoint))
411            .route("/introspect", post(introspect))
412            .route("/revoke", post(revoke))
413            .route("/userinfo", get(userinfo))
414            .route("/error", get(error_page))
415            .with_state(self)
416            .layer(cors)
417    }
418
419    pub async fn start(mut self) -> (SocketAddr, JoinHandle<()>) {
420        let port = self.config.port;
421        let addr = SocketAddr::from(([127, 0, 0, 1], port));
422        let listener = TcpListener::bind(addr).await.unwrap();
423        let local_addr = listener.local_addr().unwrap();
424        let base_url = format!(
425            "{}://{}:{}",
426            self.config.scheme,
427            self.config.host,
428            local_addr.port()
429        );
430        self.base_url = base_url;
431
432        let router = self.router();
433        let handle = tokio::spawn(async move {
434            axum::serve(listener, router).await.unwrap();
435        });
436        (local_addr, handle)
437    }
438}
439
440fn generate_code() -> String {
441    Uuid::new_v4().to_string()[..20].to_string()
442}
443
444fn generate_token() -> String {
445    format!("tok_{}", Uuid::new_v4().to_string().replace("-", ""))
446}
447
448fn issue_jwt(
449    issuer: &str,
450    client_id: &str,
451    user_id: &str,
452    requested_scope: &str,
453    expires_in: i64,
454) -> Result<String, jsonwebtoken::errors::Error> {
455    let iat = Utc::now().timestamp() as usize;
456    let exp = (Utc::now() + Duration::seconds(expires_in)).timestamp() as usize;
457
458    // Filter and clean up requested scopes
459    let scopes: Vec<&str> = requested_scope.split_whitespace().collect();
460
461    // Construct the JWT claims (payload)
462    let claims = Claims {
463        iss: issuer.to_string(),
464        sub: user_id.to_string(),
465        aud: client_id.to_string(),
466        exp,
467        iat,
468        scope: Some(scopes.join(" ")), // Only requested scopes
469        auth_time: Some(iat),
470        // New claims
471        typ: "Bearer".to_string(), // This is part of the payload
472        azp: Some(client_id.to_string()),
473        sid: Some(format!("sid-{}", Uuid::new_v4())),
474        jti: Uuid::new_v4().to_string(),
475    };
476
477    // Build the JWT header with typ: JWT
478    let mut header = Header::new(Algorithm::RS256);
479    header.typ = Some("JWT".to_string());
480    header.kid = Some(KID.to_string());
481
482    // Encode (sign) the token
483    encode(&header, &claims, &KEYS.encoding)
484}
485
486// === Global RSA Key (for JWT signing) ===
487pub static KEYS: Lazy<Keys> = Lazy::new(|| {
488    let mut rng = rand::thread_rng();
489    let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate key");
490    let public_key = private_key.to_public_key();
491
492    // Generate PEMs
493    let private_pem = private_key
494        .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
495        .expect("failed to encode private key")
496        .to_string();
497
498    let public_pem = public_key
499        .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
500        .expect("failed to encode public key")
501        .to_string();
502
503    let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap();
504    let decoding_key = DecodingKey::from_rsa_pem(public_pem.as_bytes()).unwrap();
505
506    Keys {
507        encoding: encoding_key,
508        decoding: decoding_key,
509        public_pem, // Store it
510    }
511});
512
513static JWKS_JSON: Lazy<serde_json::Value> = Lazy::new(|| {
514    let public_key = rsa::RsaPublicKey::from_public_key_pem(&KEYS.public_pem)
515        .expect("Failed to parse stored public key");
516
517    let jwk = Jwk {
518        common: CommonParameters {
519            key_algorithm: Some(jsonwebtoken::jwk::KeyAlgorithm::RS256),
520            key_id: Some(KID.to_string()),
521            ..Default::default()
522        },
523        algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA(
524            jsonwebtoken::jwk::RSAKeyParameters {
525                n: general_purpose::STANDARD.encode(public_key.n().to_bytes_be()),
526                e: general_purpose::STANDARD.encode(public_key.e().to_bytes_be()),
527                key_type: jsonwebtoken::jwk::RSAKeyType::RSA,
528            },
529        ),
530    };
531
532    json!({ "keys": [jwk] })
533});
534
535#[allow(unused)]
536pub struct Keys {
537    pub encoding: EncodingKey,
538    pub decoding: DecodingKey,
539    pub public_pem: String,
540}
541
542#[derive(Debug, Clone, Serialize, Deserialize)]
543struct AccessToken {
544    access_token: String,
545    token_type: String,
546    expires_in: i64,
547    scope: Option<String>,
548    refresh_token: Option<String>,
549    id_token: Option<String>,
550}
551
552#[derive(Debug, Clone, Serialize, Deserialize)]
553struct Claims {
554    iss: String,
555    sub: String,
556    aud: String,
557    exp: usize,
558    iat: usize,
559    scope: Option<String>,
560    auth_time: Option<usize>,
561    typ: String,         // Token type, e.g., "Bearer"
562    azp: Option<String>, // Authorized party (client_id)
563    sid: Option<String>, // Session ID
564    jti: String,         // Unique token ID
565}
566
567// 'oauth-authorization-server' | 'oauth-protected-resource' | 'openid-configuration',
568async fn well_known_openid_configuration(State(state): State<AppState>) -> impl IntoResponse {
569    let discovery = state.config.to_discovery_document(state.base_url);
570    (StatusCode::OK, Json(discovery))
571}
572
573async fn jwks() -> impl IntoResponse {
574    (StatusCode::OK, Json(JWKS_JSON.clone()))
575}
576
577async fn register_client(
578    State(state): State<AppState>,
579    Json(metadata): Json<serde_json::Value>,
580) -> impl IntoResponse {
581    let requested_scope = metadata
582        .get("scope")
583        .and_then(|v| v.as_str())
584        .unwrap_or("openid");
585
586    match state.config.validate_scope(requested_scope) {
587        Ok(_) => { /* continue */ }
588        Err(e) => {
589            return (
590                StatusCode::BAD_REQUEST,
591                Json(json!({ "error": "invalid_scope", "error_description": e })),
592            );
593        }
594    };
595
596    let client_id = Uuid::new_v4().to_string();
597
598    let client_secret = if state.config.generate_client_secret_for_dcr
599        || metadata
600            .get("token_endpoint_auth_method")
601            .and_then(|v| v.as_str())
602            != Some("none")
603    {
604        Some(generate_token())
605    } else {
606        None
607    };
608
609    let redirect_uris = metadata
610        .get("redirect_uris")
611        .and_then(|v| v.as_array())
612        .map(|arr| {
613            arr.iter()
614                .filter_map(|u| u.as_str().map(|s| s.to_string()))
615                .collect::<Vec<String>>()
616        })
617        .unwrap_or_default();
618
619    if redirect_uris.is_empty()
620        && metadata.get("grant_types").map(|v| {
621            v.as_array()
622                .map(|a| a.contains(&json!("client_credentials")))
623        }) != Some(Some(true))
624    {
625        return (
626            StatusCode::BAD_REQUEST,
627            Json(json!({"error": "redirect_uris required"})),
628        );
629    }
630
631    let client = Client {
632        client_id: client_id.clone(),
633        client_secret: client_secret.clone(),
634        redirect_uris,
635        grant_types: metadata
636            .get("grant_types")
637            .and_then(|v| v.as_array())
638            .map(|arr| {
639                arr.iter()
640                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
641                    .collect()
642            })
643            .unwrap_or_else(|| vec!["authorization_code".to_string()]),
644        response_types: metadata
645            .get("response_types")
646            .and_then(|v| v.as_array())
647            .map(|arr| {
648                arr.iter()
649                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
650                    .collect()
651            })
652            .unwrap_or_else(|| vec!["code".to_string()]),
653        scope: metadata
654            .get("scope")
655            .and_then(|v| v.as_str())
656            .unwrap_or("")
657            .to_string(),
658        token_endpoint_auth_method: metadata
659            .get("token_endpoint_auth_method")
660            .and_then(|v| v.as_str())
661            .unwrap_or("client_secret_basic")
662            .to_string(),
663        client_name: metadata
664            .get("client_name")
665            .and_then(|v| v.as_str())
666            .map(|s| s.to_string()),
667        client_uri: metadata
668            .get("client_uri")
669            .and_then(|v| v.as_str())
670            .map(|s| s.to_string()),
671        logo_uri: metadata
672            .get("logo_uri")
673            .and_then(|v| v.as_str())
674            .map(|s| s.to_string()),
675        contacts: metadata
676            .get("contacts")
677            .and_then(|v| v.as_array())
678            .map(|arr| {
679                arr.iter()
680                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
681                    .collect()
682            })
683            .unwrap_or_default(),
684        policy_uri: metadata
685            .get("policy_uri")
686            .and_then(|v| v.as_str())
687            .map(|s| s.to_string()),
688        tos_uri: metadata
689            .get("tos_uri")
690            .and_then(|v| v.as_str())
691            .map(|s| s.to_string()),
692        jwks: metadata.get("jwks").cloned(),
693        jwks_uri: metadata
694            .get("jwks_uri")
695            .and_then(|v| v.as_str())
696            .map(|s| s.to_string()),
697        software_id: metadata
698            .get("software_id")
699            .and_then(|v| v.as_str())
700            .map(|s| s.to_string()),
701        software_version: metadata
702            .get("software_version")
703            .and_then(|v| v.as_str())
704            .map(|s| s.to_string()),
705        registration_access_token: None,
706        registration_client_uri: Some(format!("{}/register/{}", state.issuer(), client_id)),
707    };
708
709    state
710        .clients
711        .write()
712        .unwrap()
713        .insert(client_id.clone(), client.clone());
714
715    let response = json!({
716        "client_id": client.client_id,
717        "client_secret": client.client_secret,
718        "client_id_issued_at": Utc::now().timestamp(),
719        "registration_client_uri": client.registration_client_uri,
720        "registration_access_token": Uuid::new_v4().to_string(),
721        "redirect_uris": client.redirect_uris,
722        "grant_types": client.grant_types,
723        "response_types": client.response_types,
724        "scope": client.scope,
725        "token_endpoint_auth_method": client.token_endpoint_auth_method
726    });
727
728    (StatusCode::CREATED, Json(response))
729}
730
731async fn get_client(
732    State(state): State<AppState>,
733    Path(client_id): Path<String>,
734) -> impl IntoResponse {
735    if let Some(client) = state.clients.read().unwrap().get(&client_id) {
736        let response = json!({
737            "client_id": client.client_id,
738            "client_name": client.client_name,
739            "redirect_uris": client.redirect_uris,
740            "grant_types": client.grant_types,
741            "scope": client.scope
742        });
743        (StatusCode::OK, Json(response))
744    } else {
745        (
746            StatusCode::NOT_FOUND,
747            Json(json!({"error": "client not found"})),
748        )
749    }
750}
751
752#[derive(Deserialize)]
753struct AuthorizeQuery {
754    response_type: String,
755    client_id: String,
756    redirect_uri: Option<String>,
757    scope: Option<String>,
758    state: Option<String>,
759    code_challenge: Option<String>,
760    code_challenge_method: Option<String>,
761}
762
763async fn authorize(
764    State(state): State<AppState>,
765    Query(params): Query<AuthorizeQuery>,
766) -> impl IntoResponse {
767    let clients = state.clients.read().unwrap();
768    let client = match clients.get(&params.client_id) {
769        Some(c) => c,
770        None => {
771            return Redirect::to(&format!(
772                "/error?error=invalid_client&state={}",
773                params.state.as_deref().unwrap_or("")
774            ))
775            .into_response();
776        }
777    };
778
779    if params.response_type != "code" {
780        return Redirect::to(&format!(
781            "/error?error=unsupported_response_type&state={}",
782            params.state.as_deref().unwrap_or("")
783        ))
784        .into_response();
785    }
786
787    let redirect_uri = match &params.redirect_uri {
788        Some(uri) => {
789            if !client.redirect_uris.contains(uri) {
790                return Redirect::to(&format!(
791                    "/error?error=invalid_request&state={}",
792                    params.state.as_deref().unwrap_or("")
793                ))
794                .into_response();
795            }
796            uri.clone()
797        }
798        None => client.redirect_uris.first().unwrap().clone(),
799    };
800
801    let code = generate_code();
802
803    // Compute the allowed scope: intersection of requested and registered
804    let requested_scopes: HashSet<String> = params
805        .scope
806        .clone()
807        .unwrap_or_default()
808        .split_whitespace()
809        .map(|s| s.to_string())
810        .collect();
811
812    let registered_scopes: HashSet<String> = client
813        .scope
814        .split_whitespace()
815        .map(|s| s.to_string())
816        .collect();
817
818    let granted_scopes: Vec<String> = requested_scopes
819        .intersection(&registered_scopes)
820        .cloned()
821        .collect();
822
823    let final_scope = granted_scopes.join(" ");
824
825    let auth_code = AuthorizationCode {
826        code: code.clone(),
827        client_id: params.client_id.clone(),
828        redirect_uri: redirect_uri.clone(),
829        scope: final_scope, // use filtered scopes
830        expires_at: Utc::now() + Duration::minutes(10),
831        code_challenge: params.code_challenge.clone(),
832        code_challenge_method: params.code_challenge_method.clone(),
833        user_id: "test-user-123".to_string(),
834    };
835
836    state.codes.write().unwrap().insert(code.clone(), auth_code);
837
838    let redirect_url = format!(
839        "{}?code={}&state={}",
840        redirect_uri,
841        code,
842        params.state.as_deref().unwrap_or("")
843    );
844
845    Redirect::to(&redirect_url).into_response()
846}
847
848#[derive(Deserialize)]
849struct TokenRequest {
850    grant_type: String,
851    code: Option<String>,
852    _redirect_uri: Option<String>,
853    client_id: Option<String>,
854    _client_secret: Option<String>,
855    refresh_token: Option<String>,
856    code_verifier: Option<String>,
857    scope: Option<String>,
858}
859
860async fn token_endpoint(
861    State(state): State<AppState>,
862    _headers: HeaderMap,
863    Form(form): Form<TokenRequest>,
864) -> impl IntoResponse {
865    if form.grant_type == "authorization_code" {
866        let code = form.code.as_deref().unwrap_or("");
867        let code_obj = match state.codes.write().unwrap().remove(code) {
868            Some(c) => c,
869            None => {
870                return (
871                    StatusCode::BAD_REQUEST,
872                    Json(json!({"error": "invalid_grant"})),
873                )
874                    .into_response();
875            }
876        };
877
878        if code_obj.expires_at < Utc::now() {
879            return (
880                StatusCode::BAD_REQUEST,
881                Json(json!({"error": "invalid_grant"})),
882            )
883                .into_response();
884        }
885
886        if let (Some(challenge), Some(verifier)) = (&code_obj.code_challenge, &form.code_verifier) {
887            let method = code_obj.code_challenge_method.as_deref().unwrap_or("plain");
888            let computed = if method == "S256" {
889                general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(verifier.as_bytes()))
890            } else {
891                verifier.clone()
892            };
893            if computed != *challenge {
894                return (
895                    StatusCode::BAD_REQUEST,
896                    Json(json!({"error": "invalid_grant"})),
897                )
898                    .into_response();
899            }
900        }
901
902        // let access_token = generate_token();
903        let refresh_token = generate_token();
904
905        let jwt = issue_jwt(
906            state.issuer(),
907            &code_obj.client_id,
908            &code_obj.user_id,
909            &code_obj.scope,
910            3600,
911        )
912        .unwrap();
913
914        let token = Token {
915            access_token: jwt.clone(),
916            refresh_token: Some(refresh_token.clone()),
917            client_id: code_obj.client_id.clone(),
918            scope: code_obj.scope.clone(),
919            expires_at: Utc::now() + Duration::hours(1),
920            user_id: code_obj.user_id.clone(),
921            revoked: false,
922        };
923
924        state
925            .tokens
926            .write()
927            .unwrap()
928            .insert(jwt.clone(), token.clone());
929        state
930            .refresh_tokens
931            .write()
932            .unwrap()
933            .insert(refresh_token.clone(), token);
934
935        let response = json!({
936            "access_token": jwt,
937            "token_type": "Bearer",
938            "expires_in": 3600,
939            "refresh_token": refresh_token,
940            "scope": code_obj.scope
941        });
942
943        (StatusCode::OK, Json(response)).into_response()
944    } else if form.grant_type == "refresh_token" {
945        let rt = form.refresh_token.as_deref().unwrap_or("");
946        let mut guard = state.refresh_tokens.write().unwrap();
947        let token = match guard.get_mut(rt) {
948            Some(t) => t,
949            None => {
950                return (
951                    StatusCode::BAD_REQUEST,
952                    Json(json!({"error": "invalid_grant"})),
953                )
954                    .into_response();
955            }
956        };
957
958        if token.revoked {
959            return (
960                StatusCode::BAD_REQUEST,
961                Json(json!({"error": "invalid_grant"})),
962            )
963                .into_response();
964        }
965
966        let new_access_token = issue_jwt(
967            state.issuer(),
968            &token.client_id,
969            &token.user_id,
970            &token.scope,
971            3600,
972        )
973        .unwrap();
974        let new_refresh_token = generate_token();
975
976        let new_token = Token {
977            access_token: new_access_token.clone(),
978            refresh_token: Some(new_refresh_token.clone()),
979            client_id: token.client_id.clone(),
980            scope: token.scope.clone(),
981            expires_at: Utc::now() + Duration::hours(1),
982            user_id: token.user_id.clone(),
983            revoked: false,
984        };
985
986        state
987            .tokens
988            .write()
989            .unwrap()
990            .insert(new_access_token.clone(), new_token.clone());
991        state
992            .refresh_tokens
993            .write()
994            .unwrap()
995            .insert(new_refresh_token.clone(), new_token);
996
997        token.revoked = true;
998
999        let response = json!({
1000            "access_token": new_access_token,
1001            "token_type": "Bearer",
1002            "expires_in": 3600,
1003            "refresh_token": new_refresh_token,
1004            "scope": token.scope
1005        });
1006
1007        (StatusCode::OK, Json(response)).into_response()
1008    } else if form.grant_type == "client_credentials" {
1009        let client_id = form.client_id.as_deref().unwrap_or("");
1010        let clients = state.clients.read().unwrap();
1011        let client = match clients.get(client_id) {
1012            Some(c) => c,
1013            None => {
1014                return (
1015                    StatusCode::BAD_REQUEST,
1016                    Json(json!({"error": "invalid_client"})),
1017                )
1018                    .into_response();
1019            }
1020        };
1021
1022        // Requested scopes from request
1023        let requested_scopes: HashSet<String> = form
1024            .scope
1025            .as_deref()
1026            .unwrap_or("")
1027            .split_whitespace()
1028            .map(|s| s.to_string())
1029            .collect();
1030
1031        if let Some(requested_scope) = form.scope.as_deref() {
1032            // 1. Must be supported by the issuer
1033            if let Err(e) = state.config.validate_scope(requested_scope) {
1034                return (
1035                    StatusCode::BAD_REQUEST,
1036                    Json(json!({
1037                        "error": "invalid_scope",
1038                        "error_description": e
1039                    })),
1040                )
1041                    .into_response();
1042            }
1043
1044            // 2. Must be allowed for this client
1045            let client_scopes: HashSet<_> = client.scope.split_whitespace().collect();
1046            let requested_scopes: HashSet<_> = requested_scope.split_whitespace().collect();
1047
1048            let not_permitted: Vec<_> = requested_scopes
1049                .difference(&client_scopes)
1050                .cloned()
1051                .collect();
1052            if !not_permitted.is_empty() {
1053                return (StatusCode::BAD_REQUEST, Json(json!({
1054                        "error": "invalid_scope",
1055                        "error_description": format!("Client not authorized for scopes: {}", not_permitted.join(" "))
1056                    }))).into_response();
1057            }
1058        }
1059
1060        // Allowed scopes from registration
1061        let registered_scopes: HashSet<String> = client
1062            .scope
1063            .split_whitespace()
1064            .map(|s| s.to_string())
1065            .collect();
1066
1067        // Intersection (only allowed scopes)
1068        let granted_scopes: Vec<String> = requested_scopes
1069            .intersection(&registered_scopes)
1070            .cloned()
1071            .collect();
1072
1073        // If none of the requested scopes are allowed, return an error
1074        if granted_scopes.is_empty() && !requested_scopes.is_empty() {
1075            return (
1076                StatusCode::BAD_REQUEST,
1077                Json(json!({
1078                    "error": "invalid_scope",
1079                    "error_description": "Requested scopes not allowed for this client"
1080                })),
1081            )
1082                .into_response();
1083        }
1084
1085        // Final scope string
1086        let final_scope = if requested_scopes.is_empty() {
1087            // No scope was requested, issue default from registration
1088            client.scope.clone()
1089        } else {
1090            granted_scopes.join(" ")
1091        };
1092
1093        // Issue JWT with only granted scopes
1094        let access_token =
1095            issue_jwt(state.issuer(), client_id, "client", &final_scope, 3600).unwrap();
1096
1097        let response = json!({
1098            "access_token": access_token,
1099            "token_type": "Bearer",
1100            "expires_in": 3600,
1101            "scope": final_scope
1102        });
1103
1104        (StatusCode::OK, Json(response)).into_response()
1105    } else {
1106        (
1107            StatusCode::BAD_REQUEST,
1108            Json(json!({"error": "unsupported_grant_type"})),
1109        )
1110            .into_response()
1111    }
1112}
1113
1114async fn introspect(
1115    State(state): State<AppState>,
1116    Form(form): Form<HashMap<String, String>>,
1117) -> impl IntoResponse {
1118    let token = form.get("token").cloned().unwrap_or_default();
1119    let token_data = state.tokens.read().unwrap().get(&token).cloned();
1120
1121    if let Some(t) = token_data {
1122        let active = !t.revoked && t.expires_at > Utc::now();
1123        let response = json!({
1124            "active": active,
1125            "client_id": t.client_id,
1126            "scope": t.scope,
1127            "sub": t.user_id,
1128            "exp": t.expires_at.timestamp(),
1129            "iat": (t.expires_at - Duration::hours(1)).timestamp(),
1130            "token_type": "Bearer"
1131        });
1132        (StatusCode::OK, Json(response)).into_response()
1133    } else {
1134        (StatusCode::OK, Json(json!({"active": false}))).into_response()
1135    }
1136}
1137
1138async fn revoke(
1139    State(state): State<AppState>,
1140    Form(form): Form<HashMap<String, String>>,
1141) -> impl IntoResponse {
1142    let token = form.get("token").cloned().unwrap_or_default();
1143    if let Some(t) = state.tokens.write().unwrap().get_mut(&token) {
1144        t.revoked = true;
1145    }
1146    if let Some(t) = state.refresh_tokens.write().unwrap().get_mut(&token) {
1147        t.revoked = true;
1148    }
1149    (StatusCode::OK, Json(json!({}))).into_response()
1150}
1151
1152async fn userinfo(headers: HeaderMap, State(state): State<AppState>) -> impl IntoResponse {
1153    let auth = headers.get("Authorization").and_then(|v| v.to_str().ok());
1154    if let Some(auth) = auth {
1155        if let Some(token) = auth.strip_prefix("Bearer ") {
1156            if let Some(t) = state.tokens.read().unwrap().get(token) {
1157                if t.revoked || t.expires_at < Utc::now() {
1158                    return (
1159                        StatusCode::UNAUTHORIZED,
1160                        Json(json!({"error": "invalid_token"})),
1161                    )
1162                        .into_response();
1163                }
1164                let response = json!({
1165                    "sub": t.user_id,
1166                    "name": "Test User",
1167                    "email": "test@example.com",
1168                    "picture": "https://example.com/avatar.jpg"
1169                });
1170                return (StatusCode::OK, Json(response)).into_response();
1171            }
1172        }
1173    }
1174    (
1175        StatusCode::UNAUTHORIZED,
1176        Json(json!({"error": "invalid_token"})),
1177    )
1178        .into_response()
1179}
1180
1181async fn error_page(Query(params): Query<HashMap<String, String>>) -> Html<String> {
1182    let error = params.get("error").map(|s| s.as_str()).unwrap_or("unknown");
1183    Html(format!("<h1>OAuth Error: {}</h1>", error))
1184}
1185
1186fn build_cors_layer(config: &IssuerConfig) -> CorsLayer {
1187    let allowed_origins: Vec<HeaderValue> = config
1188        .allowed_origins
1189        .iter()
1190        .filter_map(|s| s.parse().ok())
1191        .collect();
1192
1193    let allowed = if allowed_origins.is_empty() {
1194        AllowOrigin::any()
1195    } else {
1196        AllowOrigin::list(allowed_origins)
1197    };
1198
1199    let allowed_methods =
1200        AllowMethods::list([http::Method::GET, http::Method::POST, http::Method::OPTIONS]);
1201
1202    let allowed_headers = AllowHeaders::list([
1203        header::AUTHORIZATION,
1204        header::CONTENT_TYPE,
1205        header::ACCEPT,
1206        "x-requested-with".parse().unwrap(),
1207    ]);
1208
1209    CorsLayer::new()
1210        .allow_origin(allowed)
1211        .allow_methods(allowed_methods)
1212        .allow_headers(allowed_headers)
1213        .allow_credentials(true)
1214        .max_age(std::time::Duration::from_secs(86400))
1215}