Skip to main content

fastmcp_server/
oauth.rs

1//! OAuth 2.0/2.1 Authorization Server for MCP.
2//!
3//! This module implements a complete OAuth 2.0/2.1 authorization server for MCP
4//! servers, providing:
5//!
6//! - **Authorization Code Flow** with PKCE (required for OAuth 2.1)
7//! - **Token Issuance** - Access tokens and refresh tokens
8//! - **Token Revocation** - RFC 7009 token revocation
9//! - **Client Registration** - Dynamic client registration
10//! - **Scope Validation** - Fine-grained scope control
11//! - **Redirect URI Validation** - Security-critical validation
12//!
13//! # Architecture
14//!
15//! The OAuth server is designed to be modular:
16//!
17//! - [`OAuthServer`]: Main authorization server component
18//! - [`OAuthClient`]: Registered OAuth client
19//! - [`AuthorizationCode`]: Temporary code for token exchange
20//! - [`OAuthToken`]: Access and refresh tokens
21//! - [`OAuthTokenVerifier`]: Implements [`TokenVerifier`] for MCP integration
22//!
23//! # Security Considerations
24//!
25//! - PKCE is **required** for all authorization code flows (OAuth 2.1 compliance)
26//! - Redirect URIs must be exact matches or localhost with any port
27//! - Tokens are cryptographically random and securely generated
28//! - Authorization codes are single-use and expire quickly
29//!
30//! # Example
31//!
32//! ```ignore
33//! use fastmcp::oauth::{OAuthServer, OAuthServerConfig, OAuthClient};
34//!
35//! let oauth = OAuthServer::new(OAuthServerConfig::default());
36//!
37//! // Register a client
38//! let client = OAuthClient::builder("my-client")
39//!     .redirect_uri("http://localhost:3000/callback")
40//!     .scope("read")
41//!     .scope("write")
42//!     .build()?;
43//!
44//! oauth.register_client(client)?;
45//!
46//! // Use with MCP server
47//! let verifier = oauth.token_verifier();
48//! Server::new("my-server", "1.0.0")
49//!     .auth_provider(TokenAuthProvider::new(verifier))
50//!     .run_stdio();
51//! ```
52
53use std::collections::{HashMap, HashSet};
54use std::sync::{Arc, RwLock};
55use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
56
57use fastmcp_core::{AccessToken, AuthContext, McpContext, McpError, McpErrorCode, McpResult};
58
59use crate::auth::{AuthRequest, TokenVerifier};
60
61// =============================================================================
62// Configuration
63// =============================================================================
64
/// Tunable settings for the OAuth authorization server.
///
/// Start from [`OAuthServerConfig::default`] and override individual fields
/// with struct-update syntax when needed.
#[derive(Debug, Clone)]
pub struct OAuthServerConfig {
    /// Issuer identifier (URL) of this authorization server.
    pub issuer: String,
    /// How long an issued access token stays valid.
    pub access_token_lifetime: Duration,
    /// How long an issued refresh token stays valid.
    pub refresh_token_lifetime: Duration,
    /// How long an authorization code may be redeemed (keep short, e.g. 10 minutes).
    pub authorization_code_lifetime: Duration,
    /// Whether clients registered without a secret (public clients) are accepted.
    pub allow_public_clients: bool,
    /// Shortest accepted PKCE `code_verifier` (default: 43, the RFC 7636 minimum).
    pub min_code_verifier_length: usize,
    /// Longest accepted PKCE `code_verifier` (default: 128, the RFC 7636 maximum).
    pub max_code_verifier_length: usize,
    /// Bytes of randomness per generated token / authorization code (default: 32 = 256 bits).
    pub token_entropy_bytes: usize,
}

impl Default for OAuthServerConfig {
    /// Defaults: 1 h access tokens, 30 day refresh tokens, 10 min authorization
    /// codes, public clients allowed, verifier length 43..=128, 32 entropy bytes.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        const DAY: u64 = 24 * HOUR;

        let issuer = String::from("fastmcp");
        Self {
            issuer,
            access_token_lifetime: Duration::from_secs(HOUR),
            refresh_token_lifetime: Duration::from_secs(30 * DAY),
            authorization_code_lifetime: Duration::from_secs(10 * MINUTE),
            allow_public_clients: true,
            min_code_verifier_length: 43,
            max_code_verifier_length: 128,
            token_entropy_bytes: 32,
        }
    }
}
100
101// =============================================================================
102// OAuth Client
103// =============================================================================
104
/// OAuth client types (RFC 6749 §2.1).
///
/// [`OAuthClientBuilder::build`] chooses the type automatically: a client
/// built with a secret is [`ClientType::Confidential`], otherwise it is
/// [`ClientType::Public`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientType {
    /// Confidential client that authenticates with a client secret.
    Confidential,
    /// Public client with no secret (e.g. native apps, SPAs).
    Public,
}
113
114/// A registered OAuth client.
115#[derive(Debug, Clone)]
116pub struct OAuthClient {
117    /// Unique client identifier.
118    pub client_id: String,
119    /// Client secret (None for public clients).
120    pub client_secret: Option<String>,
121    /// Client type.
122    pub client_type: ClientType,
123    /// Allowed redirect URIs.
124    pub redirect_uris: Vec<String>,
125    /// Allowed scopes.
126    pub allowed_scopes: HashSet<String>,
127    /// Client name (for display).
128    pub name: Option<String>,
129    /// Client description.
130    pub description: Option<String>,
131    /// When the client was registered.
132    pub registered_at: SystemTime,
133}
134
135impl OAuthClient {
136    /// Creates a new client builder.
137    #[must_use]
138    pub fn builder(client_id: impl Into<String>) -> OAuthClientBuilder {
139        OAuthClientBuilder::new(client_id)
140    }
141
142    /// Validates that a redirect URI is allowed for this client.
143    #[must_use]
144    pub fn validate_redirect_uri(&self, uri: &str) -> bool {
145        // Check for exact match first
146        if self.redirect_uris.contains(&uri.to_string()) {
147            return true;
148        }
149
150        // For localhost URIs, allow any port (OAuth 2.0 for Native Apps, RFC 8252)
151        for allowed in &self.redirect_uris {
152            if is_localhost_redirect(allowed) && is_localhost_redirect(uri) {
153                // Compare scheme and path, allow different ports
154                if localhost_match(allowed, uri) {
155                    return true;
156                }
157            }
158        }
159
160        false
161    }
162
163    /// Validates that the requested scopes are allowed for this client.
164    #[must_use]
165    pub fn validate_scopes(&self, scopes: &[String]) -> bool {
166        scopes.iter().all(|s| self.allowed_scopes.contains(s))
167    }
168
169    /// Authenticates a confidential client.
170    #[must_use]
171    pub fn authenticate(&self, secret: Option<&str>) -> bool {
172        match (&self.client_secret, secret) {
173            (Some(expected), Some(provided)) => constant_time_eq(expected, provided),
174            (None, None) => self.client_type == ClientType::Public,
175            _ => false,
176        }
177    }
178}
179
180/// Builder for OAuth clients.
181#[derive(Debug)]
182pub struct OAuthClientBuilder {
183    client_id: String,
184    client_secret: Option<String>,
185    redirect_uris: Vec<String>,
186    allowed_scopes: HashSet<String>,
187    name: Option<String>,
188    description: Option<String>,
189}
190
191impl OAuthClientBuilder {
192    /// Creates a new client builder.
193    fn new(client_id: impl Into<String>) -> Self {
194        Self {
195            client_id: client_id.into(),
196            client_secret: None,
197            redirect_uris: Vec::new(),
198            allowed_scopes: HashSet::new(),
199            name: None,
200            description: None,
201        }
202    }
203
204    /// Sets the client secret (makes this a confidential client).
205    #[must_use]
206    pub fn secret(mut self, secret: impl Into<String>) -> Self {
207        self.client_secret = Some(secret.into());
208        self
209    }
210
211    /// Adds a redirect URI.
212    #[must_use]
213    pub fn redirect_uri(mut self, uri: impl Into<String>) -> Self {
214        self.redirect_uris.push(uri.into());
215        self
216    }
217
218    /// Adds multiple redirect URIs.
219    #[must_use]
220    pub fn redirect_uris<I, S>(mut self, uris: I) -> Self
221    where
222        I: IntoIterator<Item = S>,
223        S: Into<String>,
224    {
225        self.redirect_uris.extend(uris.into_iter().map(Into::into));
226        self
227    }
228
229    /// Adds an allowed scope.
230    #[must_use]
231    pub fn scope(mut self, scope: impl Into<String>) -> Self {
232        self.allowed_scopes.insert(scope.into());
233        self
234    }
235
236    /// Adds multiple allowed scopes.
237    #[must_use]
238    pub fn scopes<I, S>(mut self, scopes: I) -> Self
239    where
240        I: IntoIterator<Item = S>,
241        S: Into<String>,
242    {
243        self.allowed_scopes
244            .extend(scopes.into_iter().map(Into::into));
245        self
246    }
247
248    /// Sets the client name.
249    #[must_use]
250    pub fn name(mut self, name: impl Into<String>) -> Self {
251        self.name = Some(name.into());
252        self
253    }
254
255    /// Sets the client description.
256    #[must_use]
257    pub fn description(mut self, description: impl Into<String>) -> Self {
258        self.description = Some(description.into());
259        self
260    }
261
262    /// Builds the OAuth client.
263    ///
264    /// # Errors
265    ///
266    /// Returns an error if:
267    /// - No redirect URIs are configured
268    /// - Client ID is empty
269    pub fn build(self) -> Result<OAuthClient, OAuthError> {
270        if self.client_id.is_empty() {
271            return Err(OAuthError::InvalidRequest(
272                "client_id cannot be empty".to_string(),
273            ));
274        }
275
276        if self.redirect_uris.is_empty() {
277            return Err(OAuthError::InvalidRequest(
278                "at least one redirect_uri is required".to_string(),
279            ));
280        }
281
282        let client_type = if self.client_secret.is_some() {
283            ClientType::Confidential
284        } else {
285            ClientType::Public
286        };
287
288        Ok(OAuthClient {
289            client_id: self.client_id,
290            client_secret: self.client_secret,
291            client_type,
292            redirect_uris: self.redirect_uris,
293            allowed_scopes: self.allowed_scopes,
294            name: self.name,
295            description: self.description,
296            registered_at: SystemTime::now(),
297        })
298    }
299}
300
301// =============================================================================
302// Authorization Code
303// =============================================================================
304
/// PKCE code challenge method (RFC 7636 §4.2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodeChallengeMethod {
    /// `plain`: the challenge is the verifier itself (discouraged; kept for compatibility).
    Plain,
    /// `S256`: the challenge is the SHA-256 transform of the verifier (recommended).
    S256,
}

impl CodeChallengeMethod {
    /// Looks up a method by its exact, case-sensitive wire name.
    ///
    /// Returns `None` for anything other than `"plain"` or `"S256"`.
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        [Self::Plain, Self::S256]
            .iter()
            .copied()
            .find(|method| method.as_str() == s)
    }

    /// Returns the wire name used in the `code_challenge_method` parameter.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            CodeChallengeMethod::S256 => "S256",
            CodeChallengeMethod::Plain => "plain",
        }
    }
}
334
335/// Authorization code issued during the authorization flow.
336#[derive(Debug, Clone)]
337pub struct AuthorizationCode {
338    /// The code value.
339    pub code: String,
340    /// Client ID this code was issued to.
341    pub client_id: String,
342    /// Redirect URI used in the authorization request.
343    pub redirect_uri: String,
344    /// Approved scopes.
345    pub scopes: Vec<String>,
346    /// PKCE code challenge.
347    pub code_challenge: String,
348    /// PKCE code challenge method.
349    pub code_challenge_method: CodeChallengeMethod,
350    /// When the code was issued.
351    pub issued_at: Instant,
352    /// When the code expires.
353    pub expires_at: Instant,
354    /// Subject (user) this code was issued for.
355    pub subject: Option<String>,
356    /// State parameter from the authorization request.
357    pub state: Option<String>,
358}
359
360impl AuthorizationCode {
361    /// Checks if this code has expired.
362    #[must_use]
363    pub fn is_expired(&self) -> bool {
364        Instant::now() >= self.expires_at
365    }
366
367    /// Validates the PKCE code verifier against the stored challenge.
368    #[must_use]
369    pub fn validate_code_verifier(&self, verifier: &str) -> bool {
370        match self.code_challenge_method {
371            CodeChallengeMethod::Plain => constant_time_eq(&self.code_challenge, verifier),
372            CodeChallengeMethod::S256 => {
373                let computed = compute_s256_challenge(verifier);
374                constant_time_eq(&self.code_challenge, &computed)
375            }
376        }
377    }
378}
379
380// =============================================================================
381// OAuth Tokens
382// =============================================================================
383
/// Kind of token issued by this server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Bearer token (RFC 6750).
    Bearer,
}

impl TokenType {
    /// Returns the lowercase wire name of this token type.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            TokenType::Bearer => "bearer",
        }
    }
}
400
401/// OAuth token (access or refresh).
402#[derive(Debug, Clone)]
403pub struct OAuthToken {
404    /// Token value.
405    pub token: String,
406    /// Token type.
407    pub token_type: TokenType,
408    /// Client ID this token was issued to.
409    pub client_id: String,
410    /// Approved scopes.
411    pub scopes: Vec<String>,
412    /// When the token was issued.
413    pub issued_at: Instant,
414    /// When the token expires.
415    pub expires_at: Instant,
416    /// Subject (user) this token was issued for.
417    pub subject: Option<String>,
418    /// Whether this is a refresh token.
419    pub is_refresh_token: bool,
420}
421
422impl OAuthToken {
423    /// Checks if this token has expired.
424    #[must_use]
425    pub fn is_expired(&self) -> bool {
426        Instant::now() >= self.expires_at
427    }
428
429    /// Returns the remaining lifetime in seconds.
430    #[must_use]
431    pub fn expires_in_secs(&self) -> u64 {
432        self.expires_at
433            .saturating_duration_since(Instant::now())
434            .as_secs()
435    }
436}
437
/// Body of a successful token endpoint response (RFC 6749 §5.1).
///
/// Serialized with serde. `refresh_token` and `scope` are omitted from the
/// output when they are `None`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Token type ("bearer", the only [`TokenType`] this module defines).
    pub token_type: String,
    /// Access token lifetime in seconds.
    pub expires_in: u64,
    /// Refresh token, if one was issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Granted scopes as a space-separated list.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
454
455// =============================================================================
456// Authorization Request
457// =============================================================================
458
/// Parameters of an authorization request (RFC 6749 §4.1.1, with PKCE per RFC 7636 §4.3).
///
/// Consumed by [`OAuthServer::authorize`].
#[derive(Debug, Clone)]
pub struct AuthorizationRequest {
    /// Response type; [`OAuthServer::authorize`] accepts only `"code"`.
    pub response_type: String,
    /// Client ID.
    pub client_id: String,
    /// Redirect URI; must be allowed for the client.
    pub redirect_uri: String,
    /// Requested scopes (space-separated in the original request, split here).
    pub scopes: Vec<String>,
    /// Opaque state echoed back on the redirect (recommended for CSRF protection).
    pub state: Option<String>,
    /// PKCE code challenge; must be non-empty.
    pub code_challenge: String,
    /// PKCE code challenge method.
    pub code_challenge_method: CodeChallengeMethod,
}
477
/// Parameters of a token endpoint request.
///
/// Consumed by [`OAuthServer::token`], which dispatches on `grant_type`.
#[derive(Debug, Clone)]
pub struct TokenRequest {
    /// Grant type: `"authorization_code"` or `"refresh_token"`.
    pub grant_type: String,
    /// Authorization code (required for the `authorization_code` grant).
    pub code: Option<String>,
    /// Redirect URI (required for the `authorization_code` grant; must equal the
    /// URI used in the authorization request).
    pub redirect_uri: Option<String>,
    /// Client ID.
    pub client_id: String,
    /// Client secret (checked for confidential clients).
    pub client_secret: Option<String>,
    /// PKCE code verifier (required for the `authorization_code` grant).
    pub code_verifier: Option<String>,
    /// Refresh token (for the `refresh_token` grant).
    pub refresh_token: Option<String>,
    /// Requested scopes (for the `refresh_token` grant; subset of the original scopes).
    pub scopes: Option<Vec<String>>,
}
498
499// =============================================================================
500// OAuth Errors
501// =============================================================================
502
/// OAuth error, one variant per error code defined by RFC 6749 (§4.1.2.1, §5.2).
///
/// Every variant carries a human-readable description.
#[derive(Debug, Clone)]
pub enum OAuthError {
    /// A required parameter is missing, or the request is otherwise malformed.
    InvalidRequest(String),
    /// Client authentication failed.
    InvalidClient(String),
    /// The authorization grant or refresh token is invalid.
    InvalidGrant(String),
    /// The client may not use this grant type.
    UnauthorizedClient(String),
    /// The grant type is not supported.
    UnsupportedGrantType(String),
    /// The requested scope is invalid or unknown.
    InvalidScope(String),
    /// The server hit an unexpected condition.
    ServerError(String),
    /// The server is temporarily unavailable.
    TemporarilyUnavailable(String),
    /// The resource owner denied access.
    AccessDenied(String),
    /// The response type is not supported.
    UnsupportedResponseType(String),
}

impl OAuthError {
    /// Returns the RFC 6749 `error` code for this variant.
    #[must_use]
    pub fn error_code(&self) -> &'static str {
        match self {
            Self::AccessDenied(_) => "access_denied",
            Self::InvalidClient(_) => "invalid_client",
            Self::InvalidGrant(_) => "invalid_grant",
            Self::InvalidRequest(_) => "invalid_request",
            Self::InvalidScope(_) => "invalid_scope",
            Self::ServerError(_) => "server_error",
            Self::TemporarilyUnavailable(_) => "temporarily_unavailable",
            Self::UnauthorizedClient(_) => "unauthorized_client",
            Self::UnsupportedGrantType(_) => "unsupported_grant_type",
            Self::UnsupportedResponseType(_) => "unsupported_response_type",
        }
    }

    /// Returns the human-readable description carried by the variant.
    #[must_use]
    pub fn description(&self) -> &str {
        // Every variant wraps exactly one String, so this or-pattern is irrefutable.
        let (Self::InvalidRequest(text)
        | Self::InvalidClient(text)
        | Self::InvalidGrant(text)
        | Self::UnauthorizedClient(text)
        | Self::UnsupportedGrantType(text)
        | Self::InvalidScope(text)
        | Self::ServerError(text)
        | Self::TemporarilyUnavailable(text)
        | Self::AccessDenied(text)
        | Self::UnsupportedResponseType(text)) = self;
        text
    }
}

impl std::fmt::Display for OAuthError {
    /// Formats as `"<error_code>: <description>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let code = self.error_code();
        let description = self.description();
        write!(f, "{}: {}", code, description)
    }
}

impl std::error::Error for OAuthError {}
571
572impl From<OAuthError> for McpError {
573    fn from(err: OAuthError) -> Self {
574        match &err {
575            OAuthError::InvalidClient(_) | OAuthError::UnauthorizedClient(_) => {
576                McpError::new(McpErrorCode::ResourceForbidden, err.to_string())
577            }
578            OAuthError::AccessDenied(_) => {
579                McpError::new(McpErrorCode::ResourceForbidden, err.to_string())
580            }
581            _ => McpError::new(McpErrorCode::InvalidRequest, err.to_string()),
582        }
583    }
584}
585
586// =============================================================================
587// OAuth Server
588// =============================================================================
589
590/// Internal state for the OAuth server.
591pub(crate) struct OAuthServerState {
592    /// Registered clients by client_id.
593    pub(crate) clients: HashMap<String, OAuthClient>,
594    /// Pending authorization codes.
595    pub(crate) authorization_codes: HashMap<String, AuthorizationCode>,
596    /// Active access tokens.
597    pub(crate) access_tokens: HashMap<String, OAuthToken>,
598    /// Active refresh tokens.
599    pub(crate) refresh_tokens: HashMap<String, OAuthToken>,
600    /// Revoked tokens (for revocation checking).
601    pub(crate) revoked_tokens: HashSet<String>,
602}
603
604impl OAuthServerState {
605    fn new() -> Self {
606        Self {
607            clients: HashMap::new(),
608            authorization_codes: HashMap::new(),
609            access_tokens: HashMap::new(),
610            refresh_tokens: HashMap::new(),
611            revoked_tokens: HashSet::new(),
612        }
613    }
614}
615
/// OAuth 2.0/2.1 authorization server.
///
/// Implements the authorization code flow with mandatory PKCE, as required
/// for OAuth 2.1. Clients, pending authorization codes and tokens are held in
/// memory in an [`OAuthServerState`] guarded by an `RwLock`.
pub struct OAuthServer {
    /// Settings fixed at construction time.
    config: OAuthServerConfig,
    /// Clients, codes and tokens; methods lock this for each operation.
    pub(crate) state: RwLock<OAuthServerState>,
}
624
625impl OAuthServer {
626    /// Creates a new OAuth server with the given configuration.
627    #[must_use]
628    pub fn new(config: OAuthServerConfig) -> Self {
629        Self {
630            config,
631            state: RwLock::new(OAuthServerState::new()),
632        }
633    }
634
635    /// Creates a new OAuth server with default configuration.
636    #[must_use]
637    pub fn with_defaults() -> Self {
638        Self::new(OAuthServerConfig::default())
639    }
640
641    /// Returns the server configuration.
642    #[must_use]
643    pub fn config(&self) -> &OAuthServerConfig {
644        &self.config
645    }
646
647    // -------------------------------------------------------------------------
648    // Client Registration
649    // -------------------------------------------------------------------------
650
651    /// Registers a new OAuth client.
652    ///
653    /// # Errors
654    ///
655    /// Returns an error if:
656    /// - A client with the same ID already exists
657    /// - Public clients are not allowed and the client has no secret
658    pub fn register_client(&self, client: OAuthClient) -> Result<(), OAuthError> {
659        if client.client_type == ClientType::Public && !self.config.allow_public_clients {
660            return Err(OAuthError::InvalidClient(
661                "public clients are not allowed".to_string(),
662            ));
663        }
664
665        let mut state = self
666            .state
667            .write()
668            .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
669
670        if state.clients.contains_key(&client.client_id) {
671            return Err(OAuthError::InvalidClient(format!(
672                "client '{}' already exists",
673                client.client_id
674            )));
675        }
676
677        state.clients.insert(client.client_id.clone(), client);
678        Ok(())
679    }
680
681    /// Unregisters an OAuth client.
682    ///
683    /// This also revokes all tokens issued to the client.
684    pub fn unregister_client(&self, client_id: &str) -> Result<(), OAuthError> {
685        let mut state = self
686            .state
687            .write()
688            .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
689
690        if state.clients.remove(client_id).is_none() {
691            return Err(OAuthError::InvalidClient(format!(
692                "client '{}' not found",
693                client_id
694            )));
695        }
696
697        // Revoke all tokens for this client
698        let access_tokens: Vec<_> = state
699            .access_tokens
700            .iter()
701            .filter(|(_, t)| t.client_id == client_id)
702            .map(|(k, _)| k.clone())
703            .collect();
704        for token in access_tokens {
705            state.access_tokens.remove(&token);
706            state.revoked_tokens.insert(token);
707        }
708
709        let refresh_tokens: Vec<_> = state
710            .refresh_tokens
711            .iter()
712            .filter(|(_, t)| t.client_id == client_id)
713            .map(|(k, _)| k.clone())
714            .collect();
715        for token in refresh_tokens {
716            state.refresh_tokens.remove(&token);
717            state.revoked_tokens.insert(token);
718        }
719
720        // Remove pending authorization codes
721        let codes: Vec<_> = state
722            .authorization_codes
723            .iter()
724            .filter(|(_, c)| c.client_id == client_id)
725            .map(|(k, _)| k.clone())
726            .collect();
727        for code in codes {
728            state.authorization_codes.remove(&code);
729        }
730
731        Ok(())
732    }
733
734    /// Gets a registered client by ID.
735    #[must_use]
736    pub fn get_client(&self, client_id: &str) -> Option<OAuthClient> {
737        self.state
738            .read()
739            .ok()
740            .and_then(|s| s.clients.get(client_id).cloned())
741    }
742
743    /// Lists all registered clients.
744    #[must_use]
745    pub fn list_clients(&self) -> Vec<OAuthClient> {
746        self.state
747            .read()
748            .map(|s| s.clients.values().cloned().collect())
749            .unwrap_or_default()
750    }
751
752    // -------------------------------------------------------------------------
753    // Authorization Endpoint
754    // -------------------------------------------------------------------------
755
756    /// Validates an authorization request and creates an authorization code.
757    ///
758    /// This is called after the resource owner has authenticated and approved
759    /// the authorization request.
760    ///
761    /// # Arguments
762    ///
763    /// * `request` - The authorization request parameters
764    /// * `subject` - The authenticated user's identifier (optional)
765    ///
766    /// # Returns
767    ///
768    /// Returns the authorization code and redirect URI on success.
769    pub fn authorize(
770        &self,
771        request: &AuthorizationRequest,
772        subject: Option<String>,
773    ) -> Result<(String, String), OAuthError> {
774        // Validate response_type
775        if request.response_type != "code" {
776            return Err(OAuthError::UnsupportedResponseType(
777                "only 'code' response_type is supported".to_string(),
778            ));
779        }
780
781        // Get and validate client
782        let client = self.get_client(&request.client_id).ok_or_else(|| {
783            OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
784        })?;
785
786        // Validate redirect URI
787        if !client.validate_redirect_uri(&request.redirect_uri) {
788            return Err(OAuthError::InvalidRequest(
789                "invalid redirect_uri".to_string(),
790            ));
791        }
792
793        // Validate scopes
794        if !client.validate_scopes(&request.scopes) {
795            return Err(OAuthError::InvalidScope(
796                "requested scope not allowed".to_string(),
797            ));
798        }
799
800        // Validate PKCE (required for OAuth 2.1)
801        if request.code_challenge.is_empty() {
802            return Err(OAuthError::InvalidRequest(
803                "code_challenge is required (PKCE)".to_string(),
804            ));
805        }
806
807        // Generate authorization code
808        let code_value = generate_token(self.config.token_entropy_bytes);
809        let now = Instant::now();
810        let code = AuthorizationCode {
811            code: code_value.clone(),
812            client_id: request.client_id.clone(),
813            redirect_uri: request.redirect_uri.clone(),
814            scopes: request.scopes.clone(),
815            code_challenge: request.code_challenge.clone(),
816            code_challenge_method: request.code_challenge_method,
817            issued_at: now,
818            expires_at: now + self.config.authorization_code_lifetime,
819            subject,
820            state: request.state.clone(),
821        };
822
823        // Store the code
824        {
825            let mut state = self
826                .state
827                .write()
828                .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
829            state.authorization_codes.insert(code_value.clone(), code);
830        }
831
832        // Build redirect URI with code
833        let mut redirect = request.redirect_uri.clone();
834        let separator = if redirect.contains('?') { '&' } else { '?' };
835        redirect.push(separator);
836        redirect.push_str("code=");
837        redirect.push_str(&url_encode(&code_value));
838        if let Some(state) = &request.state {
839            redirect.push_str("&state=");
840            redirect.push_str(&url_encode(state));
841        }
842
843        Ok((code_value, redirect))
844    }
845
846    // -------------------------------------------------------------------------
847    // Token Endpoint
848    // -------------------------------------------------------------------------
849
850    /// Exchanges an authorization code or refresh token for tokens.
851    pub fn token(&self, request: &TokenRequest) -> Result<TokenResponse, OAuthError> {
852        match request.grant_type.as_str() {
853            "authorization_code" => self.token_authorization_code(request),
854            "refresh_token" => self.token_refresh_token(request),
855            other => Err(OAuthError::UnsupportedGrantType(format!(
856                "grant_type '{}' is not supported",
857                other
858            ))),
859        }
860    }
861
862    fn token_authorization_code(
863        &self,
864        request: &TokenRequest,
865    ) -> Result<TokenResponse, OAuthError> {
866        // Validate required parameters
867        let code_value = request
868            .code
869            .as_ref()
870            .ok_or_else(|| OAuthError::InvalidRequest("code is required".to_string()))?;
871        let redirect_uri = request
872            .redirect_uri
873            .as_ref()
874            .ok_or_else(|| OAuthError::InvalidRequest("redirect_uri is required".to_string()))?;
875        let code_verifier = request.code_verifier.as_ref().ok_or_else(|| {
876            OAuthError::InvalidRequest("code_verifier is required (PKCE)".to_string())
877        })?;
878
879        // Validate code verifier length
880        if code_verifier.len() < self.config.min_code_verifier_length
881            || code_verifier.len() > self.config.max_code_verifier_length
882        {
883            return Err(OAuthError::InvalidRequest(format!(
884                "code_verifier must be between {} and {} characters",
885                self.config.min_code_verifier_length, self.config.max_code_verifier_length
886            )));
887        }
888
889        // Get and validate authorization code
890        let auth_code = {
891            let mut state = self
892                .state
893                .write()
894                .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
895
896            // Remove the code (single-use)
897            state
898                .authorization_codes
899                .remove(code_value)
900                .ok_or_else(|| {
901                    OAuthError::InvalidGrant(
902                        "authorization code not found or already used".to_string(),
903                    )
904                })?
905        };
906
907        // Validate the code
908        if auth_code.is_expired() {
909            return Err(OAuthError::InvalidGrant(
910                "authorization code has expired".to_string(),
911            ));
912        }
913        if auth_code.client_id != request.client_id {
914            return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
915        }
916        if auth_code.redirect_uri != *redirect_uri {
917            return Err(OAuthError::InvalidGrant(
918                "redirect_uri mismatch".to_string(),
919            ));
920        }
921
922        // Validate PKCE
923        if !auth_code.validate_code_verifier(code_verifier) {
924            return Err(OAuthError::InvalidGrant(
925                "code_verifier validation failed".to_string(),
926            ));
927        }
928
929        // Authenticate client (if confidential)
930        let client = self.get_client(&request.client_id).ok_or_else(|| {
931            OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
932        })?;
933
934        if client.client_type == ClientType::Confidential {
935            if !client.authenticate(request.client_secret.as_deref()) {
936                return Err(OAuthError::InvalidClient(
937                    "client authentication failed".to_string(),
938                ));
939            }
940        }
941
942        // Issue tokens
943        self.issue_tokens(
944            &auth_code.client_id,
945            &auth_code.scopes,
946            auth_code.subject.as_deref(),
947        )
948    }
949
950    fn token_refresh_token(&self, request: &TokenRequest) -> Result<TokenResponse, OAuthError> {
951        let refresh_token_value = request
952            .refresh_token
953            .as_ref()
954            .ok_or_else(|| OAuthError::InvalidRequest("refresh_token is required".to_string()))?;
955
956        // Get and validate refresh token
957        let refresh_token = {
958            let state = self
959                .state
960                .read()
961                .map_err(|_| OAuthError::ServerError("failed to acquire read lock".to_string()))?;
962
963            // Check if revoked
964            if state.revoked_tokens.contains(refresh_token_value) {
965                return Err(OAuthError::InvalidGrant(
966                    "refresh token has been revoked".to_string(),
967                ));
968            }
969
970            state
971                .refresh_tokens
972                .get(refresh_token_value)
973                .cloned()
974                .ok_or_else(|| OAuthError::InvalidGrant("refresh token not found".to_string()))?
975        };
976
977        if refresh_token.is_expired() {
978            return Err(OAuthError::InvalidGrant(
979                "refresh token has expired".to_string(),
980            ));
981        }
982        if refresh_token.client_id != request.client_id {
983            return Err(OAuthError::InvalidGrant("client_id mismatch".to_string()));
984        }
985
986        // Authenticate client
987        let client = self.get_client(&request.client_id).ok_or_else(|| {
988            OAuthError::InvalidClient(format!("client '{}' not found", request.client_id))
989        })?;
990
991        if client.client_type == ClientType::Confidential {
992            if !client.authenticate(request.client_secret.as_deref()) {
993                return Err(OAuthError::InvalidClient(
994                    "client authentication failed".to_string(),
995                ));
996            }
997        }
998
999        // Determine scopes (subset of original if specified)
1000        let scopes = if let Some(requested) = &request.scopes {
1001            // Validate that requested scopes are a subset of original
1002            for scope in requested {
1003                if !refresh_token.scopes.contains(scope) {
1004                    return Err(OAuthError::InvalidScope(format!(
1005                        "scope '{}' was not in original grant",
1006                        scope
1007                    )));
1008                }
1009            }
1010            requested.clone()
1011        } else {
1012            refresh_token.scopes.clone()
1013        };
1014
1015        // Issue new access token (keep same refresh token)
1016        let now = Instant::now();
1017        let access_token_value = generate_token(self.config.token_entropy_bytes);
1018        let access_token = OAuthToken {
1019            token: access_token_value.clone(),
1020            token_type: TokenType::Bearer,
1021            client_id: request.client_id.clone(),
1022            scopes: scopes.clone(),
1023            issued_at: now,
1024            expires_at: now + self.config.access_token_lifetime,
1025            subject: refresh_token.subject.clone(),
1026            is_refresh_token: false,
1027        };
1028
1029        // Store new access token
1030        {
1031            let mut state = self
1032                .state
1033                .write()
1034                .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1035            state
1036                .access_tokens
1037                .insert(access_token_value.clone(), access_token.clone());
1038        }
1039
1040        Ok(TokenResponse {
1041            access_token: access_token_value,
1042            token_type: access_token.token_type.as_str().to_string(),
1043            expires_in: access_token.expires_in_secs(),
1044            refresh_token: None, // Don't issue new refresh token
1045            scope: if scopes.is_empty() {
1046                None
1047            } else {
1048                Some(scopes.join(" "))
1049            },
1050        })
1051    }
1052
1053    fn issue_tokens(
1054        &self,
1055        client_id: &str,
1056        scopes: &[String],
1057        subject: Option<&str>,
1058    ) -> Result<TokenResponse, OAuthError> {
1059        let now = Instant::now();
1060
1061        // Generate access token
1062        let access_token_value = generate_token(self.config.token_entropy_bytes);
1063        let access_token = OAuthToken {
1064            token: access_token_value.clone(),
1065            token_type: TokenType::Bearer,
1066            client_id: client_id.to_string(),
1067            scopes: scopes.to_vec(),
1068            issued_at: now,
1069            expires_at: now + self.config.access_token_lifetime,
1070            subject: subject.map(String::from),
1071            is_refresh_token: false,
1072        };
1073
1074        // Generate refresh token
1075        let refresh_token_value = generate_token(self.config.token_entropy_bytes);
1076        let refresh_token = OAuthToken {
1077            token: refresh_token_value.clone(),
1078            token_type: TokenType::Bearer,
1079            client_id: client_id.to_string(),
1080            scopes: scopes.to_vec(),
1081            issued_at: now,
1082            expires_at: now + self.config.refresh_token_lifetime,
1083            subject: subject.map(String::from),
1084            is_refresh_token: true,
1085        };
1086
1087        // Store tokens
1088        {
1089            let mut state = self
1090                .state
1091                .write()
1092                .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1093            state
1094                .access_tokens
1095                .insert(access_token_value.clone(), access_token.clone());
1096            state
1097                .refresh_tokens
1098                .insert(refresh_token_value.clone(), refresh_token);
1099        }
1100
1101        Ok(TokenResponse {
1102            access_token: access_token_value,
1103            token_type: access_token.token_type.as_str().to_string(),
1104            expires_in: access_token.expires_in_secs(),
1105            refresh_token: Some(refresh_token_value),
1106            scope: if scopes.is_empty() {
1107                None
1108            } else {
1109                Some(scopes.join(" "))
1110            },
1111        })
1112    }
1113
1114    // -------------------------------------------------------------------------
1115    // Token Revocation (RFC 7009)
1116    // -------------------------------------------------------------------------
1117
1118    /// Revokes a token (access or refresh).
1119    ///
1120    /// Per RFC 7009, this always returns success even if the token was not found.
1121    pub fn revoke(
1122        &self,
1123        token: &str,
1124        client_id: &str,
1125        client_secret: Option<&str>,
1126    ) -> Result<(), OAuthError> {
1127        // Authenticate client
1128        let client = self.get_client(client_id).ok_or_else(|| {
1129            OAuthError::InvalidClient(format!("client '{}' not found", client_id))
1130        })?;
1131
1132        if client.client_type == ClientType::Confidential {
1133            if !client.authenticate(client_secret) {
1134                return Err(OAuthError::InvalidClient(
1135                    "client authentication failed".to_string(),
1136                ));
1137            }
1138        }
1139
1140        let mut state = self
1141            .state
1142            .write()
1143            .map_err(|_| OAuthError::ServerError("failed to acquire write lock".to_string()))?;
1144
1145        // Try to find and remove the token
1146        let found_access = state.access_tokens.remove(token);
1147        let found_refresh = state.refresh_tokens.remove(token);
1148
1149        // Validate client owns the token (if found)
1150        if let Some(ref t) = found_access {
1151            if t.client_id != client_id {
1152                // Don't reveal that the token exists but belongs to another client
1153                return Ok(());
1154            }
1155        }
1156        if let Some(ref t) = found_refresh {
1157            if t.client_id != client_id {
1158                return Ok(());
1159            }
1160        }
1161
1162        // Mark as revoked
1163        if found_access.is_some() || found_refresh.is_some() {
1164            state.revoked_tokens.insert(token.to_string());
1165        }
1166
1167        Ok(())
1168    }
1169
1170    // -------------------------------------------------------------------------
1171    // Token Introspection
1172    // -------------------------------------------------------------------------
1173
1174    /// Validates an access token and returns its metadata.
1175    ///
1176    /// This is used internally and by the [`OAuthTokenVerifier`].
1177    pub fn validate_access_token(&self, token: &str) -> Option<OAuthToken> {
1178        let state = self.state.read().ok()?;
1179
1180        // Check if revoked
1181        if state.revoked_tokens.contains(token) {
1182            return None;
1183        }
1184
1185        let token_info = state.access_tokens.get(token)?;
1186
1187        if token_info.is_expired() {
1188            return None;
1189        }
1190
1191        Some(token_info.clone())
1192    }
1193
1194    // -------------------------------------------------------------------------
1195    // MCP Integration
1196    // -------------------------------------------------------------------------
1197
1198    /// Creates a token verifier for use with MCP [`TokenAuthProvider`].
1199    #[must_use]
1200    pub fn token_verifier(self: &Arc<Self>) -> OAuthTokenVerifier {
1201        OAuthTokenVerifier {
1202            server: Arc::clone(self),
1203        }
1204    }
1205
1206    // -------------------------------------------------------------------------
1207    // Maintenance
1208    // -------------------------------------------------------------------------
1209
1210    /// Removes expired tokens and authorization codes.
1211    ///
1212    /// Call this periodically to prevent memory growth.
1213    pub fn cleanup_expired(&self) {
1214        let Ok(mut state) = self.state.write() else {
1215            return;
1216        };
1217
1218        // Remove expired authorization codes
1219        state.authorization_codes.retain(|_, c| !c.is_expired());
1220
1221        // Remove expired access tokens
1222        state.access_tokens.retain(|_, t| !t.is_expired());
1223
1224        // Remove expired refresh tokens
1225        state.refresh_tokens.retain(|_, t| !t.is_expired());
1226    }
1227
1228    /// Returns statistics about the server state.
1229    #[must_use]
1230    pub fn stats(&self) -> OAuthServerStats {
1231        let state = self.state.read().unwrap();
1232        OAuthServerStats {
1233            clients: state.clients.len(),
1234            authorization_codes: state.authorization_codes.len(),
1235            access_tokens: state.access_tokens.len(),
1236            refresh_tokens: state.refresh_tokens.len(),
1237            revoked_tokens: state.revoked_tokens.len(),
1238        }
1239    }
1240}
1241
/// Statistics about the OAuth server state.
///
/// Produced by `OAuthServer::stats`. Each field is the raw size of the
/// corresponding collection at the moment of the call. Entries are not
/// filtered for expiry.
#[derive(Debug, Clone, Default)]
pub struct OAuthServerStats {
    /// Number of registered clients.
    pub clients: usize,
    /// Number of stored authorization codes awaiting exchange. Expired codes
    /// are still counted until `OAuthServer::cleanup_expired` removes them.
    pub authorization_codes: usize,
    /// Number of stored access tokens, including expired ones not yet removed
    /// by `OAuthServer::cleanup_expired`.
    pub access_tokens: usize,
    /// Number of stored refresh tokens, including expired ones not yet removed
    /// by `OAuthServer::cleanup_expired`.
    pub refresh_tokens: usize,
    /// Number of token values recorded as revoked. `cleanup_expired` does not
    /// prune this set.
    pub revoked_tokens: usize,
}
1256
1257// =============================================================================
1258// Token Verifier Implementation
1259// =============================================================================
1260
/// OAuth token verifier for MCP integration.
///
/// This implements [`TokenVerifier`] to allow the OAuth server to be used
/// with the MCP server's [`TokenAuthProvider`]. Construct it with
/// `OAuthServer::token_verifier`. It holds a shared `Arc` handle and checks
/// each token through `OAuthServer::validate_access_token`. Tokens issued or
/// revoked through the server are therefore reflected immediately.
pub struct OAuthTokenVerifier {
    // Shared handle to the server whose token store is consulted.
    server: Arc<OAuthServer>,
}
1268
1269impl TokenVerifier for OAuthTokenVerifier {
1270    fn verify(
1271        &self,
1272        _ctx: &McpContext,
1273        _request: AuthRequest<'_>,
1274        token: &AccessToken,
1275    ) -> McpResult<AuthContext> {
1276        // Only accept Bearer tokens
1277        if !token.scheme.eq_ignore_ascii_case("Bearer") {
1278            return Err(McpError::new(
1279                McpErrorCode::ResourceForbidden,
1280                "unsupported auth scheme",
1281            ));
1282        }
1283
1284        // Validate the token
1285        let token_info = self
1286            .server
1287            .validate_access_token(&token.token)
1288            .ok_or_else(|| {
1289                McpError::new(McpErrorCode::ResourceForbidden, "invalid or expired token")
1290            })?;
1291
1292        Ok(AuthContext {
1293            subject: token_info.subject,
1294            scopes: token_info.scopes,
1295            token: Some(token.clone()),
1296            claims: Some(serde_json::json!({
1297                "client_id": token_info.client_id,
1298                "iss": self.server.config.issuer,
1299                "iat": token_info.issued_at.elapsed().as_secs(),
1300            })),
1301        })
1302    }
1303}
1304
1305// =============================================================================
1306// Helper Functions
1307// =============================================================================
1308
1309/// Generates a cryptographically secure random token.
1310fn generate_token(bytes: usize) -> String {
1311    use std::collections::hash_map::RandomState;
1312    use std::hash::{BuildHasher, Hasher};
1313
1314    // Use system randomness via multiple hash iterations
1315    let mut result = Vec::with_capacity(bytes * 2);
1316    let state = RandomState::new();
1317
1318    for i in 0..bytes {
1319        let mut hasher = state.build_hasher();
1320        hasher.write_usize(i);
1321        hasher.write_u128(
1322            SystemTime::now()
1323                .duration_since(UNIX_EPOCH)
1324                .unwrap_or_default()
1325                .as_nanos(),
1326        );
1327        let hash = hasher.finish();
1328        result.extend_from_slice(&hash.to_le_bytes()[..2]);
1329    }
1330
1331    // Base64url encode (URL-safe, no padding)
1332    base64url_encode(&result[..bytes])
1333}
1334
/// Encodes `data` as base64url (RFC 4648 §5): the URL- and filename-safe
/// alphabet, with trailing `=` padding omitted.
///
/// Each full 3-byte group becomes 4 characters. A trailing group of 1 or 2
/// bytes becomes 2 or 3 characters respectively.
fn base64url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::with_capacity((data.len() * 4).div_ceil(3));
    for group in data.chunks(3) {
        // Pack the group big-endian into the low 24 bits; missing bytes stay zero.
        let packed = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (idx, &byte)| acc | (u32::from(byte) << (16 - 8 * idx)));
        // n input bytes carry 8n bits, which need n + 1 six-bit symbols.
        for slot in 0..=group.len() {
            let sextet = (packed >> (18 - 6 * slot)) & 0x3F;
            encoded.push(char::from(ALPHABET[sextet as usize]));
        }
    }
    encoded
}
1364
1365/// Computes S256 code challenge from a verifier.
1366fn compute_s256_challenge(verifier: &str) -> String {
1367    // Simple SHA-256 implementation for PKCE
1368    // In production, use a proper crypto library
1369    let hash = simple_sha256(verifier.as_bytes());
1370    base64url_encode(&hash)
1371}
1372
/// Computes the SHA-256 digest of `data` as specified in FIPS 180-4.
///
/// PKCE `S256` (RFC 7636 §4.2) compares the server's hash of the verifier with
/// a challenge the client computed using its own SHA-256. The hash therefore
/// has to be the standard, deterministic SHA-256. A keyed or per-process hash
/// (such as one built on `RandomState`, whose keys change on every
/// `RandomState::new()`) would never reproduce the client's challenge.
fn simple_sha256(data: &[u8]) -> [u8; 32] {
    // Round constants: first 32 bits of the fractional parts of the cube roots
    // of the first 64 primes (FIPS 180-4 §4.2.2).
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];
    // Initial hash value: fractional parts of the square roots of the first
    // 8 primes (FIPS 180-4 §5.3.3).
    const H0: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding (FIPS 180-4 §5.1.1): append 0x80, zero-fill until the length is
    // 56 mod 64, then append the message length in bits as a big-endian u64.
    let bit_len = (data.len() as u64).wrapping_mul(8);
    let mut message = Vec::with_capacity(data.len() + 72);
    message.extend_from_slice(data);
    message.push(0x80);
    while message.len() % 64 != 56 {
        message.push(0);
    }
    message.extend_from_slice(&bit_len.to_be_bytes());

    let mut state = H0;
    let mut schedule = [0u32; 64];
    for block in message.chunks_exact(64) {
        // Message schedule: 16 big-endian words, expanded to 64.
        for (word, bytes) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..64 {
            let w15 = schedule[t - 15];
            let w2 = schedule[t - 2];
            let sigma0 = w15.rotate_right(7) ^ w15.rotate_right(18) ^ (w15 >> 3);
            let sigma1 = w2.rotate_right(17) ^ w2.rotate_right(19) ^ (w2 >> 10);
            schedule[t] = schedule[t - 16]
                .wrapping_add(sigma0)
                .wrapping_add(schedule[t - 7])
                .wrapping_add(sigma1);
        }

        // 64 compression rounds.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state;
        for (&k, &w) in K.iter().zip(schedule.iter()) {
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ (!e & g);
            let t1 = h
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(k)
                .wrapping_add(w);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_sigma0.wrapping_add(majority);
            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(t1);
            d = c;
            c = b;
            b = a;
            a = t1.wrapping_add(t2);
        }
        for (word, add) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
            *word = word.wrapping_add(add);
        }
    }

    let mut digest = [0u8; 32];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
1394
/// Percent-encodes `s` byte by byte.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through
/// unchanged. Every other byte, including each byte of a multi-byte UTF-8
/// character, becomes `%XX` with uppercase hex digits.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len() * 3);
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
1411
/// Compares two strings without exiting early at the first differing byte.
///
/// Strings of different byte length are rejected immediately, so length is not
/// hidden. For equal lengths, every byte pair is XOR-ed into an accumulator
/// before the single final comparison.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (left, right) = (a.as_bytes(), b.as_bytes());
    if left.len() != right.len() {
        return false;
    }
    let difference = left
        .iter()
        .zip(right)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    difference == 0
}
1424
/// Checks if a URI is a localhost redirect.
///
/// The URI must be `http://` and its host must be exactly `localhost`,
/// `127.0.0.1` or `[::1]`, optionally followed by a numeric port. Look-alike
/// hosts are rejected. Examples are `http://localhost.evil.com/` and userinfo
/// tricks such as `http://localhost:3000@evil.com/`, whose real host is
/// `evil.com`.
fn is_localhost_redirect(uri: &str) -> bool {
    split_http_authority(uri).is_some_and(|(host, _)| is_loopback_host(host))
}

/// Checks if two localhost URIs match, ignoring the port (RFC 8252 §7.3).
///
/// Both URIs must be loopback redirects under the same rules as
/// `is_localhost_redirect`. The three loopback spellings are treated as
/// equivalent. Everything after the authority (path, query, fragment) must be
/// identical. Two non-loopback URIs never match.
fn localhost_match(a: &str, b: &str) -> bool {
    match (split_http_authority(a), split_http_authority(b)) {
        (Some((host_a, rest_a)), Some((host_b, rest_b))) => {
            is_loopback_host(host_a) && is_loopback_host(host_b) && rest_a == rest_b
        }
        _ => false,
    }
}

/// Splits an `http://` URI into `(host, rest)`.
///
/// `rest` starts at the first `/`, `?` or `#` after the authority, or is empty.
/// Returns `None` in any of these cases:
/// - the scheme is not `http://`;
/// - the authority contains userinfo (`@`);
/// - an IPv6 literal is unterminated;
/// - the port is not all ASCII digits.
fn split_http_authority(uri: &str) -> Option<(&str, &str)> {
    let after_scheme = uri.strip_prefix("http://")?;
    let authority_end = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let (authority, rest) = after_scheme.split_at(authority_end);

    // With userinfo, the text before `@` looks like the host, but a browser
    // actually navigates to whatever follows it.
    if authority.contains('@') {
        return None;
    }

    // IPv6 literals contain colons, so the port separator is the first `:`
    // after the closing bracket, not the last `:` overall.
    let (host, port_part) = if authority.starts_with('[') {
        let close = authority.find(']')?;
        authority.split_at(close + 1)
    } else {
        authority.split_at(authority.find(':').unwrap_or(authority.len()))
    };

    let port_ok = port_part.is_empty()
        || port_part
            .strip_prefix(':')
            .is_some_and(|port| port.bytes().all(|b| b.is_ascii_digit()));
    port_ok.then_some((host, rest))
}

/// Returns `true` if `host` is one of the recognised loopback spellings.
fn is_loopback_host(host: &str) -> bool {
    normalize_localhost(host) == "localhost"
}

/// Normalizes localhost variants.
///
/// Maps `localhost`, `127.0.0.1` and `[::1]` to `"localhost"` and every other
/// host to `"other"`. Matching is exact and case-sensitive.
fn normalize_localhost(host: &str) -> &'static str {
    match host {
        "localhost" | "127.0.0.1" | "[::1]" => "localhost",
        _ => "other",
    }
}
1462
1463// =============================================================================
1464// Tests
1465// =============================================================================
1466
1467#[cfg(test)]
1468mod tests {
1469    use super::*;
1470
1471    #[test]
1472    fn test_client_builder() {
1473        let client = OAuthClient::builder("test-client")
1474            .redirect_uri("http://localhost:3000/callback")
1475            .scope("read")
1476            .scope("write")
1477            .name("Test Client")
1478            .build()
1479            .unwrap();
1480
1481        assert_eq!(client.client_id, "test-client");
1482        assert_eq!(client.client_type, ClientType::Public);
1483        assert_eq!(client.redirect_uris.len(), 1);
1484        assert!(client.allowed_scopes.contains("read"));
1485        assert!(client.allowed_scopes.contains("write"));
1486    }
1487
1488    #[test]
1489    fn test_confidential_client() {
1490        let client = OAuthClient::builder("test-client")
1491            .secret("super-secret")
1492            .redirect_uri("http://localhost:3000/callback")
1493            .build()
1494            .unwrap();
1495
1496        assert_eq!(client.client_type, ClientType::Confidential);
1497        assert!(client.authenticate(Some("super-secret")));
1498        assert!(!client.authenticate(Some("wrong-secret")));
1499        assert!(!client.authenticate(None));
1500    }
1501
1502    #[test]
1503    fn test_redirect_uri_validation() {
1504        let client = OAuthClient::builder("test-client")
1505            .redirect_uri("http://localhost:3000/callback")
1506            .redirect_uri("https://example.com/oauth/callback")
1507            .build()
1508            .unwrap();
1509
1510        // Exact match
1511        assert!(client.validate_redirect_uri("http://localhost:3000/callback"));
1512        assert!(client.validate_redirect_uri("https://example.com/oauth/callback"));
1513
1514        // Localhost with different port (allowed per RFC 8252)
1515        assert!(client.validate_redirect_uri("http://localhost:8080/callback"));
1516        assert!(client.validate_redirect_uri("http://127.0.0.1:9000/callback"));
1517
1518        // Invalid
1519        assert!(!client.validate_redirect_uri("http://localhost:3000/other"));
1520        assert!(!client.validate_redirect_uri("https://evil.com/callback"));
1521    }
1522
1523    #[test]
1524    fn test_scope_validation() {
1525        let client = OAuthClient::builder("test-client")
1526            .redirect_uri("http://localhost:3000/callback")
1527            .scope("read")
1528            .scope("write")
1529            .build()
1530            .unwrap();
1531
1532        assert!(client.validate_scopes(&["read".to_string()]));
1533        assert!(client.validate_scopes(&["read".to_string(), "write".to_string()]));
1534        assert!(!client.validate_scopes(&["admin".to_string()]));
1535    }
1536
1537    #[test]
1538    fn test_oauth_server_client_registration() {
1539        let server = OAuthServer::with_defaults();
1540
1541        let client = OAuthClient::builder("test-client")
1542            .redirect_uri("http://localhost:3000/callback")
1543            .build()
1544            .unwrap();
1545
1546        server.register_client(client).unwrap();
1547
1548        // Duplicate registration should fail
1549        let client2 = OAuthClient::builder("test-client")
1550            .redirect_uri("http://localhost:3000/callback")
1551            .build()
1552            .unwrap();
1553        assert!(server.register_client(client2).is_err());
1554
1555        // Verify client exists
1556        assert!(server.get_client("test-client").is_some());
1557        assert!(server.get_client("nonexistent").is_none());
1558    }
1559
1560    #[test]
1561    fn test_authorization_flow() {
1562        let server = OAuthServer::with_defaults();
1563
1564        let client = OAuthClient::builder("test-client")
1565            .redirect_uri("http://localhost:3000/callback")
1566            .scope("read")
1567            .build()
1568            .unwrap();
1569        server.register_client(client).unwrap();
1570
1571        // Create authorization request
1572        let request = AuthorizationRequest {
1573            response_type: "code".to_string(),
1574            client_id: "test-client".to_string(),
1575            redirect_uri: "http://localhost:3000/callback".to_string(),
1576            scopes: vec!["read".to_string()],
1577            state: Some("xyz".to_string()),
1578            code_challenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM".to_string(),
1579            code_challenge_method: CodeChallengeMethod::S256,
1580        };
1581
1582        let (code, redirect) = server
1583            .authorize(&request, Some("user123".to_string()))
1584            .unwrap();
1585
1586        assert!(!code.is_empty());
1587        assert!(redirect.contains("code="));
1588        assert!(redirect.contains("state=xyz"));
1589    }
1590
1591    #[test]
1592    fn test_pkce_required() {
1593        let server = OAuthServer::with_defaults();
1594
1595        let client = OAuthClient::builder("test-client")
1596            .redirect_uri("http://localhost:3000/callback")
1597            .build()
1598            .unwrap();
1599        server.register_client(client).unwrap();
1600
1601        // Request without PKCE should fail
1602        let request = AuthorizationRequest {
1603            response_type: "code".to_string(),
1604            client_id: "test-client".to_string(),
1605            redirect_uri: "http://localhost:3000/callback".to_string(),
1606            scopes: vec![],
1607            state: None,
1608            code_challenge: String::new(), // Missing!
1609            code_challenge_method: CodeChallengeMethod::S256,
1610        };
1611
1612        let result = server.authorize(&request, None);
1613        assert!(matches!(result, Err(OAuthError::InvalidRequest(_))));
1614    }
1615
1616    #[test]
1617    fn test_token_generation() {
1618        let token1 = generate_token(32);
1619        let token2 = generate_token(32);
1620
1621        // Tokens should be unique
1622        assert_ne!(token1, token2);
1623        // Tokens should be URL-safe
1624        assert!(
1625            token1
1626                .chars()
1627                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
1628        );
1629    }
1630
1631    #[test]
1632    fn test_base64url_encode() {
1633        // Test vectors from RFC 4648
1634        assert_eq!(base64url_encode(b""), "");
1635        assert_eq!(base64url_encode(b"f"), "Zg");
1636        assert_eq!(base64url_encode(b"fo"), "Zm8");
1637        assert_eq!(base64url_encode(b"foo"), "Zm9v");
1638        assert_eq!(base64url_encode(b"foob"), "Zm9vYg");
1639        assert_eq!(base64url_encode(b"fooba"), "Zm9vYmE");
1640        assert_eq!(base64url_encode(b"foobar"), "Zm9vYmFy");
1641    }
1642
1643    #[test]
1644    fn test_url_encode() {
1645        assert_eq!(url_encode("hello"), "hello");
1646        assert_eq!(url_encode("hello world"), "hello%20world");
1647        assert_eq!(url_encode("a=b&c=d"), "a%3Db%26c%3Dd");
1648    }
1649
1650    #[test]
1651    fn test_constant_time_eq() {
1652        assert!(constant_time_eq("hello", "hello"));
1653        assert!(!constant_time_eq("hello", "world"));
1654        assert!(!constant_time_eq("hello", "hell"));
1655    }
1656
1657    #[test]
1658    fn test_localhost_match() {
1659        assert!(localhost_match(
1660            "http://localhost:3000/callback",
1661            "http://localhost:8080/callback"
1662        ));
1663        assert!(localhost_match(
1664            "http://127.0.0.1:3000/callback",
1665            "http://localhost:8080/callback"
1666        ));
1667        assert!(!localhost_match(
1668            "http://localhost:3000/callback",
1669            "http://localhost:3000/other"
1670        ));
1671    }
1672
1673    #[test]
1674    fn test_oauth_server_stats() {
1675        let server = OAuthServer::with_defaults();
1676
1677        let stats = server.stats();
1678        assert_eq!(stats.clients, 0);
1679        assert_eq!(stats.access_tokens, 0);
1680
1681        let client = OAuthClient::builder("test-client")
1682            .redirect_uri("http://localhost:3000/callback")
1683            .build()
1684            .unwrap();
1685        server.register_client(client).unwrap();
1686
1687        let stats = server.stats();
1688        assert_eq!(stats.clients, 1);
1689    }
1690
1691    #[test]
1692    fn test_code_challenge_method_parse() {
1693        assert_eq!(
1694            CodeChallengeMethod::parse("plain"),
1695            Some(CodeChallengeMethod::Plain)
1696        );
1697        assert_eq!(
1698            CodeChallengeMethod::parse("S256"),
1699            Some(CodeChallengeMethod::S256)
1700        );
1701        assert_eq!(CodeChallengeMethod::parse("unknown"), None);
1702    }
1703
1704    #[test]
1705    fn test_oauth_error_display() {
1706        let err = OAuthError::InvalidRequest("missing parameter".to_string());
1707        assert_eq!(err.error_code(), "invalid_request");
1708        assert_eq!(err.description(), "missing parameter");
1709        assert_eq!(err.to_string(), "invalid_request: missing parameter");
1710    }
1711
1712    #[test]
1713    fn test_token_revocation() {
1714        let server = Arc::new(OAuthServer::with_defaults());
1715
1716        // Register a client
1717        let client = OAuthClient::builder("test-client")
1718            .redirect_uri("http://localhost:3000/callback")
1719            .scope("read")
1720            .build()
1721            .unwrap();
1722        server.register_client(client).unwrap();
1723
1724        // Manually create a token for testing
1725        let token_response = {
1726            let mut state = server.state.write().unwrap();
1727            let now = Instant::now();
1728            let token = OAuthToken {
1729                token: "test-access-token".to_string(),
1730                token_type: TokenType::Bearer,
1731                client_id: "test-client".to_string(),
1732                scopes: vec!["read".to_string()],
1733                issued_at: now,
1734                expires_at: now + Duration::from_secs(3600),
1735                subject: Some("user123".to_string()),
1736                is_refresh_token: false,
1737            };
1738            state
1739                .access_tokens
1740                .insert("test-access-token".to_string(), token);
1741            TokenResponse {
1742                access_token: "test-access-token".to_string(),
1743                token_type: "bearer".to_string(),
1744                expires_in: 3600,
1745                refresh_token: None,
1746                scope: Some("read".to_string()),
1747            }
1748        };
1749
1750        // Token should be valid
1751        assert!(
1752            server
1753                .validate_access_token(&token_response.access_token)
1754                .is_some()
1755        );
1756
1757        // Revoke the token
1758        server
1759            .revoke(&token_response.access_token, "test-client", None)
1760            .unwrap();
1761
1762        // Token should no longer be valid
1763        assert!(
1764            server
1765                .validate_access_token(&token_response.access_token)
1766                .is_none()
1767        );
1768    }
1769
1770    #[test]
1771    fn test_client_unregistration() {
1772        let server = OAuthServer::with_defaults();
1773
1774        let client = OAuthClient::builder("test-client")
1775            .redirect_uri("http://localhost:3000/callback")
1776            .build()
1777            .unwrap();
1778        server.register_client(client).unwrap();
1779
1780        assert!(server.get_client("test-client").is_some());
1781
1782        server.unregister_client("test-client").unwrap();
1783
1784        assert!(server.get_client("test-client").is_none());
1785
1786        // Unregistering again should fail
1787        assert!(server.unregister_client("test-client").is_err());
1788    }
1789
1790    #[test]
1791    fn test_token_verifier() {
1792        let server = Arc::new(OAuthServer::with_defaults());
1793
1794        // Register a client and create a token
1795        let client = OAuthClient::builder("test-client")
1796            .redirect_uri("http://localhost:3000/callback")
1797            .scope("read")
1798            .build()
1799            .unwrap();
1800        server.register_client(client).unwrap();
1801
1802        // Create a token manually
1803        {
1804            let mut state = server.state.write().unwrap();
1805            let now = Instant::now();
1806            let token = OAuthToken {
1807                token: "valid-token".to_string(),
1808                token_type: TokenType::Bearer,
1809                client_id: "test-client".to_string(),
1810                scopes: vec!["read".to_string()],
1811                issued_at: now,
1812                expires_at: now + Duration::from_secs(3600),
1813                subject: Some("user123".to_string()),
1814                is_refresh_token: false,
1815            };
1816            state.access_tokens.insert("valid-token".to_string(), token);
1817        }
1818
1819        // Create verifier
1820        let verifier = server.token_verifier();
1821        let cx = asupersync::Cx::for_testing();
1822        let mcp_ctx = McpContext::new(cx, 1);
1823        let auth_request = AuthRequest {
1824            method: "test",
1825            params: None,
1826            request_id: 1,
1827        };
1828
1829        // Valid token
1830        let access = AccessToken {
1831            scheme: "Bearer".to_string(),
1832            token: "valid-token".to_string(),
1833        };
1834        let result = verifier.verify(&mcp_ctx, auth_request, &access);
1835        assert!(result.is_ok());
1836        let auth = result.unwrap();
1837        assert_eq!(auth.subject, Some("user123".to_string()));
1838        assert_eq!(auth.scopes, vec!["read".to_string()]);
1839
1840        // Invalid token
1841        let invalid = AccessToken {
1842            scheme: "Bearer".to_string(),
1843            token: "invalid-token".to_string(),
1844        };
1845        let result = verifier.verify(&mcp_ctx, auth_request, &invalid);
1846        assert!(result.is_err());
1847
1848        // Wrong scheme
1849        let wrong_scheme = AccessToken {
1850            scheme: "Basic".to_string(),
1851            token: "valid-token".to_string(),
1852        };
1853        let result = verifier.verify(&mcp_ctx, auth_request, &wrong_scheme);
1854        assert!(result.is_err());
1855    }
1856}