oauth2_test_server/
server.rs

1use axum::{
2    extract::{Form, Json, Path, Query, State},
3    http::{header, HeaderMap, StatusCode},
4    response::{Html, IntoResponse, Redirect},
5    routing::{get, post},
6    Router,
7};
8use base64::{engine::general_purpose, Engine};
9use chrono::{Duration, Utc};
10use http::HeaderValue;
11use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header};
12use jsonwebtoken::{
13    jwk::{CommonParameters, Jwk},
14    Validation,
15};
16use once_cell::sync::Lazy;
17use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
18use rsa::traits::PublicKeyParts;
19use rsa::RsaPrivateKey;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22use sha2::Digest;
23use std::{
24    collections::{HashMap, HashSet},
25    net::SocketAddr,
26    sync::{Arc, RwLock},
27};
28use tokio::{net::TcpListener, task::JoinHandle};
29use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
30use uuid::Uuid;
31
/// Key ID (`kid`) written into every JWT header this server signs and into the single JWK it
/// publishes, so verifiers can match a token to its key.
pub const KID: &str = "oauth-by-rustmcp";
33
34// src/config.rs
35use serde_json::Value;
36
/// Static configuration of the test authorization server: where it listens and which
/// OAuth 2.0 / OpenID Connect capabilities it advertises.
#[derive(Debug, Clone)]
pub struct IssuerConfig {
    /// URL scheme used when building the issuer / base URL (e.g. `http`).
    pub scheme: String,
    /// Host name used when building the issuer / base URL.
    pub host: String,
    /// Port to bind; `0` lets the operating system choose one.
    pub port: u16,

    // OIDC / OAuth capabilities
    /// Scopes accepted at client registration (`validate_scope`) and advertised in discovery.
    pub scopes_supported: HashSet<String>,
    /// Claims advertised in the discovery document.
    pub claims_supported: Vec<String>,
    /// Grant types advertised in discovery and checked by `validate_grant_type`.
    pub grant_types_supported: HashSet<String>,
    /// Response types advertised in discovery.
    /// NOTE(review): `/authorize` only accepts `code`; the default also lists `token` and
    /// `id_token`, which may mislead clients — confirm whether that is intended.
    pub response_types_supported: HashSet<String>,
    /// Client authentication methods advertised for the token endpoint.
    pub token_endpoint_auth_methods_supported: HashSet<String>,
    /// PKCE challenge methods advertised in discovery.
    pub code_challenge_methods_supported: HashSet<String>,
    /// Subject identifier types advertised in discovery.
    pub subject_types_supported: Vec<String>,
    /// ID-token signing algorithms advertised in discovery.
    pub id_token_signing_alg_values_supported: Vec<String>,
    /// When `true`, dynamic registration issues a client secret even to clients that request
    /// `token_endpoint_auth_method: "none"`.
    pub generate_client_secret_for_dcr: bool,
    /// Browser origins intended for the CORS layer.
    /// NOTE(review): consumed by `build_cors_layer`, which is defined elsewhere — confirm.
    pub allowed_origins: Vec<String>,
}

impl Default for IssuerConfig {
    /// Localhost on an OS-assigned port, with a broad set of standard scopes, grants and
    /// client-authentication methods enabled.
    fn default() -> Self {
        fn set(items: &[&str]) -> HashSet<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        }
        fn list(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        }

        Self {
            scheme: "http".to_owned(),
            host: "localhost".to_owned(),
            // random: the OS assigns a free port at bind time
            port: 0,
            scopes_supported: set(&[
                "openid",
                "profile",
                "email",
                "offline_access",
                "address",
                "phone",
            ]),
            claims_supported: list(&[
                "sub",
                "name",
                "given_name",
                "family_name",
                "email",
                "email_verified",
                "picture",
                "locale",
            ]),
            grant_types_supported: set(&[
                "authorization_code",
                "refresh_token",
                "client_credentials",
            ]),
            response_types_supported: set(&["code", "token", "id_token"]),
            token_endpoint_auth_methods_supported: set(&[
                "client_secret_basic",
                "client_secret_post",
                "none",
                "private_key_jwt",
            ]),
            code_challenge_methods_supported: set(&["plain", "S256"]),
            subject_types_supported: list(&["public"]),
            id_token_signing_alg_values_supported: list(&["RS256"]),
            generate_client_secret_for_dcr: true,
            allowed_origins: list(&[
                "http://localhost:3001",
                "http://localhost:8080",
                "http://localhost:6274",
            ]),
        }
    }
}
113
114impl IssuerConfig {
115    pub fn to_discovery_document(&self, issuer: String) -> Value {
116        let iss = issuer;
117        json!({
118            "issuer": iss,
119            "authorization_endpoint": format!("{}/authorize", iss),
120            "token_endpoint": format!("{}/token", iss),
121            "userinfo_endpoint": format!("{}/userinfo", iss),
122            "jwks_uri": format!("{}/.well-known/jwks.json", iss),
123            "registration_endpoint": format!("{}/register", iss),
124            "revocation_endpoint": format!("{}/revoke", iss),
125            "introspection_endpoint": format!("{}/introspect", iss),
126            "scopes_supported": self.scopes_supported.iter().collect::<Vec<_>>(),
127            "claims_supported": &self.claims_supported,
128            "grant_types_supported": self.grant_types_supported.iter().collect::<Vec<_>>(),
129            "response_types_supported": self.response_types_supported.iter().collect::<Vec<_>>(),
130            "token_endpoint_auth_methods_supported": self.token_endpoint_auth_methods_supported.iter().collect::<Vec<_>>(),
131            "code_challenge_methods_supported": self.code_challenge_methods_supported.iter().collect::<Vec<_>>(),
132            "subject_types_supported": &self.subject_types_supported,
133            "id_token_signing_alg_values_supported": &self.id_token_signing_alg_values_supported,
134        })
135    }
136
137    pub fn validate_scope(&self, scope: &str) -> Result<String, String> {
138        let requested: HashSet<_> = scope.split_whitespace().map(|s| s.to_string()).collect();
139        let unknown: Vec<_> = requested
140            .difference(&self.scopes_supported)
141            .cloned()
142            .collect();
143        if unknown.is_empty() {
144            Ok(scope.to_string())
145        } else {
146            Err(format!("invalid_scope: {}", unknown.join(" ")))
147        }
148    }
149
150    pub fn validate_grant_type(&self, grant: &str) -> bool {
151        self.grant_types_supported.contains(grant)
152    }
153}
154
155// #[derive(Clone, Debug)]
156// pub struct Config {
157//     pub scheme: String,
158//     pub host: String,
159//     pub port: u16,
160//     pub generate_client_secret_for_dcr: bool,
161//     pub allowed_origins: Vec<String>,
162// }
163
164// impl Default for Config {
165//     fn default() -> Self {
166//         Self {
167//             scheme: "http".to_string(),
168//             host: "localhost".to_string(),
169//             port: 8090,
170//             generate_client_secret_for_dcr: true,
171//             allowed_origins: vec![
172//                 "http://localhost:3001".to_string(),
173//                 "http://localhost:8080".to_string(),
174//                 "http://localhost:6274".to_string(),
175//             ],
176//         }
177//     }
178// }
179
180#[derive(Clone)]
181pub struct AppState {
182    pub config: Arc<IssuerConfig>,
183    pub base_url: String,
184    pub clients: Arc<RwLock<HashMap<String, Client>>>,
185    pub codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
186    pub tokens: Arc<RwLock<HashMap<String, Token>>>,
187    pub refresh_tokens: Arc<RwLock<HashMap<String, Token>>>,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191pub struct Client {
192    pub client_id: String,
193    pub client_secret: Option<String>,
194    pub redirect_uris: Vec<String>,
195    pub grant_types: Vec<String>,
196    pub response_types: Vec<String>,
197    pub scope: String,
198    pub token_endpoint_auth_method: String,
199    pub client_name: Option<String>,
200    pub client_uri: Option<String>,
201    pub logo_uri: Option<String>,
202    pub contacts: Vec<String>,
203    pub policy_uri: Option<String>,
204    pub tos_uri: Option<String>,
205    pub jwks: Option<serde_json::Value>,
206    pub jwks_uri: Option<String>,
207    pub software_id: Option<String>,
208    pub software_version: Option<String>,
209    pub registration_access_token: Option<String>,
210    pub registration_client_uri: Option<String>,
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct AuthorizationCode {
215    pub code: String,
216    pub client_id: String,
217    pub redirect_uri: String,
218    pub scope: String,
219    pub expires_at: chrono::DateTime<Utc>,
220    pub code_challenge: Option<String>,
221    pub code_challenge_method: Option<String>,
222    pub user_id: String,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct Token {
227    pub access_token: String,
228    pub refresh_token: Option<String>,
229    pub client_id: String,
230    pub scope: String,
231    pub expires_at: chrono::DateTime<Utc>,
232    pub user_id: String,
233    pub revoked: bool,
234}
235
236impl AppState {
237    pub fn new(config: IssuerConfig) -> Self {
238        let base_url = format!("{}://{}:{}", config.scheme, config.host, config.port);
239        Self {
240            config: Arc::new(config),
241            clients: Arc::new(RwLock::new(HashMap::new())),
242            codes: Arc::new(RwLock::new(HashMap::new())),
243            tokens: Arc::new(RwLock::new(HashMap::new())),
244            refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
245            base_url,
246        }
247    }
248
249    pub fn issuer(&self) -> &str {
250        self.base_url.as_str()
251    }
252
253    pub fn register_client(
254        &self,
255        metadata: serde_json::Value,
256    ) -> Result<Client, (StatusCode, Json<serde_json::Value>)> {
257        let requested_scope = metadata
258            .get("scope")
259            .and_then(|v| v.as_str())
260            .unwrap_or("openid");
261
262        self.config
263            .validate_scope(requested_scope)
264            .map_err(|err| (StatusCode::BAD_REQUEST, Json(json!({"error": err}))))?;
265
266        let client_id = Uuid::new_v4().to_string();
267
268        let client_secret = if self.config.generate_client_secret_for_dcr
269            || metadata
270                .get("token_endpoint_auth_method")
271                .and_then(|v| v.as_str())
272                != Some("none")
273        {
274            Some(generate_token())
275        } else {
276            None
277        };
278
279        let redirect_uris = metadata
280            .get("redirect_uris")
281            .and_then(|v| v.as_array())
282            .map(|arr| {
283                arr.iter()
284                    .filter_map(|u| u.as_str().map(|s| s.to_string()))
285                    .collect::<Vec<String>>()
286            })
287            .unwrap_or_default();
288
289        if redirect_uris.is_empty()
290            && metadata.get("grant_types").map(|v| {
291                v.as_array()
292                    .map(|a| a.contains(&json!("client_credentials")))
293            }) != Some(Some(true))
294        {
295            return Err((
296                StatusCode::BAD_REQUEST,
297                Json(json!({"error": "redirect_uris required"})),
298            ));
299        }
300
301        let client = Client {
302            client_id: client_id.clone(),
303            client_secret: client_secret.clone(),
304            redirect_uris,
305            grant_types: metadata
306                .get("grant_types")
307                .and_then(|v| v.as_array())
308                .map(|arr| {
309                    arr.iter()
310                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
311                        .collect()
312                })
313                .unwrap_or_else(|| vec!["authorization_code".to_string()]),
314            response_types: metadata
315                .get("response_types")
316                .and_then(|v| v.as_array())
317                .map(|arr| {
318                    arr.iter()
319                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
320                        .collect()
321                })
322                .unwrap_or_else(|| vec!["code".to_string()]),
323            scope: metadata
324                .get("scope")
325                .and_then(|v| v.as_str())
326                .unwrap_or("")
327                .to_string(),
328            token_endpoint_auth_method: metadata
329                .get("token_endpoint_auth_method")
330                .and_then(|v| v.as_str())
331                .unwrap_or("client_secret_basic")
332                .to_string(),
333            client_name: metadata
334                .get("client_name")
335                .and_then(|v| v.as_str())
336                .map(|s| s.to_string()),
337            client_uri: metadata
338                .get("client_uri")
339                .and_then(|v| v.as_str())
340                .map(|s| s.to_string()),
341            logo_uri: metadata
342                .get("logo_uri")
343                .and_then(|v| v.as_str())
344                .map(|s| s.to_string()),
345            contacts: metadata
346                .get("contacts")
347                .and_then(|v| v.as_array())
348                .map(|arr| {
349                    arr.iter()
350                        .filter_map(|v| v.as_str().map(|s| s.to_string()))
351                        .collect()
352                })
353                .unwrap_or_default(),
354            policy_uri: metadata
355                .get("policy_uri")
356                .and_then(|v| v.as_str())
357                .map(|s| s.to_string()),
358            tos_uri: metadata
359                .get("tos_uri")
360                .and_then(|v| v.as_str())
361                .map(|s| s.to_string()),
362            jwks: metadata.get("jwks").cloned(),
363            jwks_uri: metadata
364                .get("jwks_uri")
365                .and_then(|v| v.as_str())
366                .map(|s| s.to_string()),
367            software_id: metadata
368                .get("software_id")
369                .and_then(|v| v.as_str())
370                .map(|s| s.to_string()),
371            software_version: metadata
372                .get("software_version")
373                .and_then(|v| v.as_str())
374                .map(|s| s.to_string()),
375            registration_access_token: None,
376            registration_client_uri: Some(format!("{}/register/{}", self.issuer(), client_id)),
377        };
378
379        self.clients
380            .write()
381            .unwrap()
382            .insert(client_id.clone(), client.clone());
383
384        Ok(client)
385    }
386
387    pub fn generate_token(
388        &self,
389        client: &Client,
390        options: crate::testkit::JwtOptions,
391    ) -> Result<Token, jsonwebtoken::errors::Error> {
392        let user_id = options.user_id.clone();
393        let jwt = self.generate_jwt(client, options)?;
394        let refresh_token = generate_token();
395        let token = Token {
396            access_token: jwt.clone(),
397            refresh_token: Some(refresh_token),
398            client_id: client.client_id.clone(),
399            scope: client.scope.clone(),
400            expires_at: Utc::now() + Duration::hours(1),
401            user_id,
402            revoked: false,
403        };
404
405        self.tokens
406            .write()
407            .unwrap()
408            .insert(jwt.clone(), token.clone());
409        Ok(token)
410    }
411    pub fn generate_jwt(
412        &self,
413        client: &Client,
414        options: crate::testkit::JwtOptions,
415    ) -> Result<String, jsonwebtoken::errors::Error> {
416        let scope = options.scope.unwrap_or_else(|| client.scope.clone());
417        issue_jwt(
418            self.issuer(),
419            &client.client_id,
420            &options.user_id,
421            &scope,
422            options.expires_in,
423        )
424    }
425
426    pub fn router(self) -> Router {
427        let cors = build_cors_layer(&self.config);
428        Router::new()
429            .route(
430                "/.well-known/openid-configuration",
431                get(well_known_openid_configuration),
432            )
433            .route("/.well-known/jwks.json", get(jwks))
434            .route("/register", post(register_client))
435            .route("/register/{client_id}", get(get_client))
436            .route("/authorize", get(authorize))
437            .route("/token", post(token_endpoint))
438            .route("/introspect", post(introspect))
439            .route("/revoke", post(revoke))
440            .route("/userinfo", get(userinfo))
441            .route("/error", get(error_page))
442            .with_state(self)
443            .layer(cors)
444    }
445
446    pub async fn start(mut self) -> (SocketAddr, JoinHandle<()>) {
447        let port = self.config.port;
448        let addr = SocketAddr::from(([127, 0, 0, 1], port));
449        let listener = TcpListener::bind(addr).await.unwrap();
450        let local_addr = listener.local_addr().unwrap();
451        let base_url = format!(
452            "{}://{}:{}",
453            self.config.scheme,
454            self.config.host,
455            local_addr.port()
456        );
457        self.base_url = base_url;
458
459        let router = self.router();
460        let handle = tokio::spawn(async move {
461            axum::serve(listener, router).await.unwrap();
462        });
463        (local_addr, handle)
464    }
465}
466
467fn generate_code() -> String {
468    Uuid::new_v4().to_string()[..20].to_string()
469}
470
471fn generate_token() -> String {
472    format!("tok_{}", Uuid::new_v4().to_string().replace("-", ""))
473}
474
475fn issue_jwt(
476    issuer: &str,
477    client_id: &str,
478    user_id: &str,
479    requested_scope: &str,
480    expires_in: i64,
481) -> Result<String, jsonwebtoken::errors::Error> {
482    let iat = Utc::now().timestamp() as usize;
483    let exp = (Utc::now() + Duration::seconds(expires_in)).timestamp() as usize;
484
485    // Filter and clean up requested scopes
486    let scopes: Vec<&str> = requested_scope.split_whitespace().collect();
487
488    // Construct the JWT claims (payload)
489    let claims = Claims {
490        iss: issuer.to_string(),
491        sub: user_id.to_string(),
492        aud: client_id.to_string(),
493        exp,
494        iat,
495        scope: Some(scopes.join(" ")), // Only requested scopes
496        auth_time: Some(iat),
497        // New claims
498        typ: "Bearer".to_string(), // This is part of the payload
499        azp: Some(client_id.to_string()),
500        sid: Some(format!("sid-{}", Uuid::new_v4())),
501        jti: Uuid::new_v4().to_string(),
502    };
503
504    // Build the JWT header with typ: JWT
505    let mut header = Header::new(Algorithm::RS256);
506    header.typ = Some("JWT".to_string());
507    header.kid = Some(KID.to_string());
508
509    // Encode (sign) the token
510    encode(&header, &claims, &KEYS.encoding)
511}
512
513// === Global RSA Key (for JWT signing) ===
514pub static KEYS: Lazy<Keys> = Lazy::new(|| {
515    let mut rng = rand::thread_rng();
516    let private_key = RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate key");
517    let public_key = private_key.to_public_key();
518
519    // Generate PEMs
520    let private_pem = private_key
521        .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
522        .expect("failed to encode private key")
523        .to_string();
524
525    let public_pem = public_key
526        .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
527        .expect("failed to encode public key")
528        .to_string();
529
530    let encoding_key = EncodingKey::from_rsa_pem(private_pem.as_bytes()).unwrap();
531    let decoding_key = DecodingKey::from_rsa_pem(public_pem.as_bytes()).unwrap();
532
533    Keys {
534        encoding: encoding_key,
535        decoding: decoding_key,
536        public_pem, // Store it
537    }
538});
539
540static JWKS_JSON: Lazy<serde_json::Value> = Lazy::new(|| {
541    let public_key = rsa::RsaPublicKey::from_public_key_pem(&KEYS.public_pem)
542        .expect("Failed to parse stored public key");
543
544    let jwk = Jwk {
545        common: CommonParameters {
546            key_algorithm: Some(jsonwebtoken::jwk::KeyAlgorithm::RS256),
547            key_id: Some(KID.to_string()),
548            ..Default::default()
549        },
550        algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA(
551            jsonwebtoken::jwk::RSAKeyParameters {
552                n: general_purpose::URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be()),
553                e: general_purpose::URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be()),
554                key_type: jsonwebtoken::jwk::RSAKeyType::RSA,
555            },
556        ),
557    };
558
559    json!({ "keys": [jwk] })
560});
561
562#[allow(unused)]
563pub struct Keys {
564    pub encoding: EncodingKey,
565    pub decoding: DecodingKey,
566    pub public_pem: String,
567}
568
569#[derive(Debug, Clone, Serialize, Deserialize)]
570struct AccessToken {
571    access_token: String,
572    token_type: String,
573    expires_in: i64,
574    scope: Option<String>,
575    refresh_token: Option<String>,
576    id_token: Option<String>,
577}
578
579#[derive(Debug, Clone, Serialize, Deserialize)]
580struct Claims {
581    iss: String,
582    sub: String,
583    aud: String,
584    exp: usize,
585    iat: usize,
586    scope: Option<String>,
587    auth_time: Option<usize>,
588    typ: String,         // Token type, e.g., "Bearer"
589    azp: Option<String>, // Authorized party (client_id)
590    sid: Option<String>, // Session ID
591    jti: String,         // Unique token ID
592}
593
594// 'oauth-authorization-server' | 'oauth-protected-resource' | 'openid-configuration',
595async fn well_known_openid_configuration(State(state): State<AppState>) -> impl IntoResponse {
596    let discovery = state.config.to_discovery_document(state.base_url);
597    (StatusCode::OK, Json(discovery))
598}
599
600async fn jwks() -> impl IntoResponse {
601    (StatusCode::OK, Json(JWKS_JSON.clone()))
602}
603
604async fn register_client(
605    State(state): State<AppState>,
606    Json(metadata): Json<serde_json::Value>,
607) -> impl IntoResponse {
608    let requested_scope = metadata
609        .get("scope")
610        .and_then(|v| v.as_str())
611        .unwrap_or("openid");
612
613    match state.config.validate_scope(requested_scope) {
614        Ok(_) => { /* continue */ }
615        Err(e) => {
616            return (
617                StatusCode::BAD_REQUEST,
618                Json(json!({ "error": "invalid_scope", "error_description": e })),
619            );
620        }
621    };
622
623    let client_id = Uuid::new_v4().to_string();
624
625    let client_secret = if state.config.generate_client_secret_for_dcr
626        || metadata
627            .get("token_endpoint_auth_method")
628            .and_then(|v| v.as_str())
629            != Some("none")
630    {
631        Some(generate_token())
632    } else {
633        None
634    };
635
636    let redirect_uris = metadata
637        .get("redirect_uris")
638        .and_then(|v| v.as_array())
639        .map(|arr| {
640            arr.iter()
641                .filter_map(|u| u.as_str().map(|s| s.to_string()))
642                .collect::<Vec<String>>()
643        })
644        .unwrap_or_default();
645
646    if redirect_uris.is_empty()
647        && metadata.get("grant_types").map(|v| {
648            v.as_array()
649                .map(|a| a.contains(&json!("client_credentials")))
650        }) != Some(Some(true))
651    {
652        return (
653            StatusCode::BAD_REQUEST,
654            Json(json!({"error": "redirect_uris required"})),
655        );
656    }
657
658    let client = Client {
659        client_id: client_id.clone(),
660        client_secret: client_secret.clone(),
661        redirect_uris,
662        grant_types: metadata
663            .get("grant_types")
664            .and_then(|v| v.as_array())
665            .map(|arr| {
666                arr.iter()
667                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
668                    .collect()
669            })
670            .unwrap_or_else(|| vec!["authorization_code".to_string()]),
671        response_types: metadata
672            .get("response_types")
673            .and_then(|v| v.as_array())
674            .map(|arr| {
675                arr.iter()
676                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
677                    .collect()
678            })
679            .unwrap_or_else(|| vec!["code".to_string()]),
680        scope: metadata
681            .get("scope")
682            .and_then(|v| v.as_str())
683            .unwrap_or("")
684            .to_string(),
685        token_endpoint_auth_method: metadata
686            .get("token_endpoint_auth_method")
687            .and_then(|v| v.as_str())
688            .unwrap_or("client_secret_basic")
689            .to_string(),
690        client_name: metadata
691            .get("client_name")
692            .and_then(|v| v.as_str())
693            .map(|s| s.to_string()),
694        client_uri: metadata
695            .get("client_uri")
696            .and_then(|v| v.as_str())
697            .map(|s| s.to_string()),
698        logo_uri: metadata
699            .get("logo_uri")
700            .and_then(|v| v.as_str())
701            .map(|s| s.to_string()),
702        contacts: metadata
703            .get("contacts")
704            .and_then(|v| v.as_array())
705            .map(|arr| {
706                arr.iter()
707                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
708                    .collect()
709            })
710            .unwrap_or_default(),
711        policy_uri: metadata
712            .get("policy_uri")
713            .and_then(|v| v.as_str())
714            .map(|s| s.to_string()),
715        tos_uri: metadata
716            .get("tos_uri")
717            .and_then(|v| v.as_str())
718            .map(|s| s.to_string()),
719        jwks: metadata.get("jwks").cloned(),
720        jwks_uri: metadata
721            .get("jwks_uri")
722            .and_then(|v| v.as_str())
723            .map(|s| s.to_string()),
724        software_id: metadata
725            .get("software_id")
726            .and_then(|v| v.as_str())
727            .map(|s| s.to_string()),
728        software_version: metadata
729            .get("software_version")
730            .and_then(|v| v.as_str())
731            .map(|s| s.to_string()),
732        registration_access_token: None,
733        registration_client_uri: Some(format!("{}/register/{}", state.issuer(), client_id)),
734    };
735
736    state
737        .clients
738        .write()
739        .unwrap()
740        .insert(client_id.clone(), client.clone());
741
742    let response = json!({
743        "client_id": client.client_id,
744        "client_secret": client.client_secret,
745        "client_id_issued_at": Utc::now().timestamp(),
746        "registration_client_uri": client.registration_client_uri,
747        "registration_access_token": Uuid::new_v4().to_string(),
748        "redirect_uris": client.redirect_uris,
749        "grant_types": client.grant_types,
750        "response_types": client.response_types,
751        "scope": client.scope,
752        "token_endpoint_auth_method": client.token_endpoint_auth_method
753    });
754
755    (StatusCode::CREATED, Json(response))
756}
757
758async fn get_client(
759    State(state): State<AppState>,
760    Path(client_id): Path<String>,
761) -> impl IntoResponse {
762    if let Some(client) = state.clients.read().unwrap().get(&client_id) {
763        let response = json!({
764            "client_id": client.client_id,
765            "client_name": client.client_name,
766            "redirect_uris": client.redirect_uris,
767            "grant_types": client.grant_types,
768            "scope": client.scope
769        });
770        (StatusCode::OK, Json(response))
771    } else {
772        (
773            StatusCode::NOT_FOUND,
774            Json(json!({"error": "client not found"})),
775        )
776    }
777}
778
779#[derive(Deserialize)]
780struct AuthorizeQuery {
781    response_type: String,
782    client_id: String,
783    redirect_uri: Option<String>,
784    scope: Option<String>,
785    state: Option<String>,
786    code_challenge: Option<String>,
787    code_challenge_method: Option<String>,
788}
789
790async fn authorize(
791    State(state): State<AppState>,
792    Query(params): Query<AuthorizeQuery>,
793) -> impl IntoResponse {
794    let clients = state.clients.read().unwrap();
795    let client = match clients.get(&params.client_id) {
796        Some(c) => c,
797        None => {
798            return Redirect::to(&format!(
799                "/error?error=invalid_client&state={}",
800                params.state.as_deref().unwrap_or("")
801            ))
802            .into_response();
803        }
804    };
805
806    if params.response_type != "code" {
807        return Redirect::to(&format!(
808            "/error?error=unsupported_response_type&state={}",
809            params.state.as_deref().unwrap_or("")
810        ))
811        .into_response();
812    }
813
814    let redirect_uri = match &params.redirect_uri {
815        Some(uri) => {
816            if !client.redirect_uris.contains(uri) {
817                return Redirect::to(&format!(
818                    "/error?error=invalid_request&state={}",
819                    params.state.as_deref().unwrap_or("")
820                ))
821                .into_response();
822            }
823            uri.clone()
824        }
825        None => client.redirect_uris.first().unwrap().clone(),
826    };
827
828    let code = generate_code();
829
830    // Compute the allowed scope: intersection of requested and registered
831    let requested_scopes: HashSet<String> = params
832        .scope
833        .clone()
834        .unwrap_or_default()
835        .split_whitespace()
836        .map(|s| s.to_string())
837        .collect();
838
839    let registered_scopes: HashSet<String> = client
840        .scope
841        .split_whitespace()
842        .map(|s| s.to_string())
843        .collect();
844
845    let granted_scopes: Vec<String> = requested_scopes
846        .intersection(&registered_scopes)
847        .cloned()
848        .collect();
849
850    let final_scope = granted_scopes.join(" ");
851
852    let auth_code = AuthorizationCode {
853        code: code.clone(),
854        client_id: params.client_id.clone(),
855        redirect_uri: redirect_uri.clone(),
856        scope: final_scope, // use filtered scopes
857        expires_at: Utc::now() + Duration::minutes(10),
858        code_challenge: params.code_challenge.clone(),
859        code_challenge_method: params.code_challenge_method.clone(),
860        user_id: "test-user-123".to_string(),
861    };
862
863    state.codes.write().unwrap().insert(code.clone(), auth_code);
864
865    let redirect_url = format!(
866        "{}?code={}&state={}",
867        redirect_uri,
868        code,
869        params.state.as_deref().unwrap_or("")
870    );
871
872    Redirect::to(&redirect_url).into_response()
873}
874
875#[derive(Deserialize)]
876struct TokenRequest {
877    grant_type: String,
878    code: Option<String>,
879    _redirect_uri: Option<String>,
880    client_id: Option<String>,
881    _client_secret: Option<String>,
882    refresh_token: Option<String>,
883    code_verifier: Option<String>,
884    scope: Option<String>,
885}
886
887async fn token_endpoint(
888    State(state): State<AppState>,
889    _headers: HeaderMap,
890    Form(form): Form<TokenRequest>,
891) -> impl IntoResponse {
892    if form.grant_type == "authorization_code" {
893        let code = form.code.as_deref().unwrap_or("");
894        let code_obj = match state.codes.write().unwrap().remove(code) {
895            Some(c) => c,
896            None => {
897                return (
898                    StatusCode::BAD_REQUEST,
899                    Json(json!({"error": "invalid_grant"})),
900                )
901                    .into_response();
902            }
903        };
904
905        if code_obj.expires_at < Utc::now() {
906            return (
907                StatusCode::BAD_REQUEST,
908                Json(json!({"error": "invalid_grant"})),
909            )
910                .into_response();
911        }
912
913        if let (Some(challenge), Some(verifier)) = (&code_obj.code_challenge, &form.code_verifier) {
914            let method = code_obj.code_challenge_method.as_deref().unwrap_or("plain");
915            let computed = if method == "S256" {
916                general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(verifier.as_bytes()))
917            } else {
918                verifier.clone()
919            };
920            if computed != *challenge {
921                return (
922                    StatusCode::BAD_REQUEST,
923                    Json(json!({"error": "invalid_grant"})),
924                )
925                    .into_response();
926            }
927        }
928
929        // let access_token = generate_token();
930        let refresh_token = generate_token();
931
932        let jwt = issue_jwt(
933            state.issuer(),
934            &code_obj.client_id,
935            &code_obj.user_id,
936            &code_obj.scope,
937            3600,
938        )
939        .unwrap();
940
941        let token = Token {
942            access_token: jwt.clone(),
943            refresh_token: Some(refresh_token.clone()),
944            client_id: code_obj.client_id.clone(),
945            scope: code_obj.scope.clone(),
946            expires_at: Utc::now() + Duration::hours(1),
947            user_id: code_obj.user_id.clone(),
948            revoked: false,
949        };
950
951        state
952            .tokens
953            .write()
954            .unwrap()
955            .insert(jwt.clone(), token.clone());
956        state
957            .refresh_tokens
958            .write()
959            .unwrap()
960            .insert(refresh_token.clone(), token);
961
962        let response = json!({
963            "access_token": jwt,
964            "token_type": "Bearer",
965            "expires_in": 3600,
966            "refresh_token": refresh_token,
967            "scope": code_obj.scope
968        });
969
970        (StatusCode::OK, Json(response)).into_response()
971    } else if form.grant_type == "refresh_token" {
972        let rt = form.refresh_token.as_deref().unwrap_or("");
973        let mut guard = state.refresh_tokens.write().unwrap();
974        let token = match guard.get_mut(rt) {
975            Some(t) => t,
976            None => {
977                return (
978                    StatusCode::BAD_REQUEST,
979                    Json(json!({"error": "invalid_grant"})),
980                )
981                    .into_response();
982            }
983        };
984
985        if token.revoked {
986            return (
987                StatusCode::BAD_REQUEST,
988                Json(json!({"error": "invalid_grant"})),
989            )
990                .into_response();
991        }
992
993        let new_access_token = issue_jwt(
994            state.issuer(),
995            &token.client_id,
996            &token.user_id,
997            &token.scope,
998            3600,
999        )
1000        .unwrap();
1001        let new_refresh_token = generate_token();
1002
1003        let new_token = Token {
1004            access_token: new_access_token.clone(),
1005            refresh_token: Some(new_refresh_token.clone()),
1006            client_id: token.client_id.clone(),
1007            scope: token.scope.clone(),
1008            expires_at: Utc::now() + Duration::hours(1),
1009            user_id: token.user_id.clone(),
1010            revoked: false,
1011        };
1012
1013        state
1014            .tokens
1015            .write()
1016            .unwrap()
1017            .insert(new_access_token.clone(), new_token.clone());
1018        state
1019            .refresh_tokens
1020            .write()
1021            .unwrap()
1022            .insert(new_refresh_token.clone(), new_token);
1023
1024        token.revoked = true;
1025
1026        let response = json!({
1027            "access_token": new_access_token,
1028            "token_type": "Bearer",
1029            "expires_in": 3600,
1030            "refresh_token": new_refresh_token,
1031            "scope": token.scope
1032        });
1033
1034        (StatusCode::OK, Json(response)).into_response()
1035    } else if form.grant_type == "client_credentials" {
1036        let client_id = form.client_id.as_deref().unwrap_or("");
1037        let clients = state.clients.read().unwrap();
1038        let client = match clients.get(client_id) {
1039            Some(c) => c,
1040            None => {
1041                return (
1042                    StatusCode::BAD_REQUEST,
1043                    Json(json!({"error": "invalid_client"})),
1044                )
1045                    .into_response();
1046            }
1047        };
1048
1049        // Requested scopes from request
1050        let requested_scopes: HashSet<String> = form
1051            .scope
1052            .as_deref()
1053            .unwrap_or("")
1054            .split_whitespace()
1055            .map(|s| s.to_string())
1056            .collect();
1057
1058        if let Some(requested_scope) = form.scope.as_deref() {
1059            // 1. Must be supported by the issuer
1060            if let Err(e) = state.config.validate_scope(requested_scope) {
1061                return (
1062                    StatusCode::BAD_REQUEST,
1063                    Json(json!({
1064                        "error": "invalid_scope",
1065                        "error_description": e
1066                    })),
1067                )
1068                    .into_response();
1069            }
1070
1071            // 2. Must be allowed for this client
1072            let client_scopes: HashSet<_> = client.scope.split_whitespace().collect();
1073            let requested_scopes: HashSet<_> = requested_scope.split_whitespace().collect();
1074
1075            let not_permitted: Vec<_> = requested_scopes
1076                .difference(&client_scopes)
1077                .cloned()
1078                .collect();
1079            if !not_permitted.is_empty() {
1080                return (StatusCode::BAD_REQUEST, Json(json!({
1081                        "error": "invalid_scope",
1082                        "error_description": format!("Client not authorized for scopes: {}", not_permitted.join(" "))
1083                    }))).into_response();
1084            }
1085        }
1086
1087        // Allowed scopes from registration
1088        let registered_scopes: HashSet<String> = client
1089            .scope
1090            .split_whitespace()
1091            .map(|s| s.to_string())
1092            .collect();
1093
1094        // Intersection (only allowed scopes)
1095        let granted_scopes: Vec<String> = requested_scopes
1096            .intersection(&registered_scopes)
1097            .cloned()
1098            .collect();
1099
1100        // If none of the requested scopes are allowed, return an error
1101        if granted_scopes.is_empty() && !requested_scopes.is_empty() {
1102            return (
1103                StatusCode::BAD_REQUEST,
1104                Json(json!({
1105                    "error": "invalid_scope",
1106                    "error_description": "Requested scopes not allowed for this client"
1107                })),
1108            )
1109                .into_response();
1110        }
1111
1112        // Final scope string
1113        let final_scope = if requested_scopes.is_empty() {
1114            // No scope was requested, issue default from registration
1115            client.scope.clone()
1116        } else {
1117            granted_scopes.join(" ")
1118        };
1119
1120        // Issue JWT with only granted scopes
1121        let access_token =
1122            issue_jwt(state.issuer(), client_id, "client", &final_scope, 3600).unwrap();
1123
1124        let response = json!({
1125            "access_token": access_token,
1126            "token_type": "Bearer",
1127            "expires_in": 3600,
1128            "scope": final_scope
1129        });
1130
1131        (StatusCode::OK, Json(response)).into_response()
1132    } else {
1133        (
1134            StatusCode::BAD_REQUEST,
1135            Json(json!({"error": "unsupported_grant_type"})),
1136        )
1137            .into_response()
1138    }
1139}
1140
1141async fn introspect(
1142    State(state): State<AppState>,
1143    Form(form): Form<HashMap<String, String>>,
1144) -> impl IntoResponse {
1145    let token = match form.get("token") {
1146        Some(t) => t,
1147        None => {
1148            return (
1149                StatusCode::BAD_REQUEST,
1150                Json(json!({"error": "invalid_request"})),
1151            )
1152                .into_response()
1153        }
1154    };
1155
1156    // First check if token exists in our storage (optional but recommended)
1157    let stored_token = state.tokens.read().unwrap().get(token).cloned();
1158
1159    // Decode the JWT (without full validation if you want to allow expired tokens for introspection)
1160    let mut validation = Validation::new(Algorithm::RS256);
1161    validation.validate_exp = false; // Allow introspection of expired tokens
1162    validation.required_spec_claims.clear(); // Be permissive
1163    validation.validate_aud = false;
1164
1165    match jsonwebtoken::decode::<Claims>(token, &KEYS.decoding, &validation) {
1166        Ok(token_data) => {
1167            let claims = token_data.claims;
1168
1169            // Determine if active based on your internal state + expiry
1170            let is_expired = Utc::now().timestamp() > claims.exp as i64;
1171            let is_revoked = stored_token.map(|t| t.revoked).unwrap_or(false);
1172            let active = !is_revoked && !is_expired;
1173
1174            let mut response = json!({
1175                "active": active,
1176                "scope": claims.scope,
1177                "client_id": claims.aud,
1178                "sub": claims.sub,
1179                "iss": claims.iss,
1180                "aud": claims.aud,
1181                "iat": claims.iat,
1182                "exp": claims.exp,
1183                "jti": claims.jti,
1184                "token_type": "Bearer",
1185                "azp": claims.azp,
1186                "auth_time": claims.auth_time,
1187                "sid": claims.sid,
1188            });
1189
1190            // Only include scope if present
1191            if claims.scope.is_none() {
1192                response.as_object_mut().unwrap().remove("scope");
1193            }
1194
1195            (StatusCode::OK, Json(response)).into_response()
1196        }
1197        Err(err) => {
1198            println!(">>> err {:?} ", err);
1199
1200            // If we can't decode it at all, it's invalid
1201            (StatusCode::OK, Json(json!({"active": false}))).into_response()
1202        }
1203    }
1204}
1205
1206async fn revoke(
1207    State(state): State<AppState>,
1208    Form(form): Form<HashMap<String, String>>,
1209) -> impl IntoResponse {
1210    let token = form.get("token").cloned().unwrap_or_default();
1211    if let Some(t) = state.tokens.write().unwrap().get_mut(&token) {
1212        t.revoked = true;
1213    }
1214    if let Some(t) = state.refresh_tokens.write().unwrap().get_mut(&token) {
1215        t.revoked = true;
1216    }
1217    (StatusCode::OK, Json(json!({}))).into_response()
1218}
1219
1220async fn userinfo(headers: HeaderMap, State(state): State<AppState>) -> impl IntoResponse {
1221    let auth = headers.get("Authorization").and_then(|v| v.to_str().ok());
1222    if let Some(auth) = auth {
1223        if let Some(token) = auth.strip_prefix("Bearer ") {
1224            if let Some(t) = state.tokens.read().unwrap().get(token) {
1225                if t.revoked || t.expires_at < Utc::now() {
1226                    return (
1227                        StatusCode::UNAUTHORIZED,
1228                        Json(json!({"error": "invalid_token"})),
1229                    )
1230                        .into_response();
1231                }
1232                let response = json!({
1233                    "sub": t.user_id,
1234                    "name": "Test User",
1235                    "email": "test@example.com",
1236                    "picture": "https://example.com/avatar.jpg"
1237                });
1238                return (StatusCode::OK, Json(response)).into_response();
1239            }
1240        }
1241    }
1242    (
1243        StatusCode::UNAUTHORIZED,
1244        Json(json!({"error": "invalid_token"})),
1245    )
1246        .into_response()
1247}
1248
1249async fn error_page(Query(params): Query<HashMap<String, String>>) -> Html<String> {
1250    let error = params.get("error").map(|s| s.as_str()).unwrap_or("unknown");
1251    Html(format!("<h1>OAuth Error: {}</h1>", error))
1252}
1253
1254fn build_cors_layer(config: &IssuerConfig) -> CorsLayer {
1255    let allowed_origins: Vec<HeaderValue> = config
1256        .allowed_origins
1257        .iter()
1258        .filter_map(|s| s.parse().ok())
1259        .collect();
1260
1261    let allowed = if allowed_origins.is_empty() {
1262        AllowOrigin::any()
1263    } else {
1264        AllowOrigin::list(allowed_origins)
1265    };
1266
1267    let allowed_methods =
1268        AllowMethods::list([http::Method::GET, http::Method::POST, http::Method::OPTIONS]);
1269
1270    let allowed_headers = AllowHeaders::list([
1271        header::AUTHORIZATION,
1272        header::CONTENT_TYPE,
1273        header::ACCEPT,
1274        "x-requested-with".parse().unwrap(),
1275    ]);
1276
1277    CorsLayer::new()
1278        .allow_origin(allowed)
1279        .allow_methods(allowed_methods)
1280        .allow_headers(allowed_headers)
1281        .allow_credentials(true)
1282        .max_age(std::time::Duration::from_secs(86400))
1283}