pmcp/server/auth/
oauth2.rs

1//! OAuth 2.0 server implementation for MCP.
2
3use crate::error::{Error, ErrorCode, Result};
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12/// OAuth 2.0 grant types.
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub enum GrantType {
15    /// Authorization code grant type.
16    #[serde(rename = "authorization_code")]
17    AuthorizationCode,
18    /// Refresh token grant type.
19    #[serde(rename = "refresh_token")]
20    RefreshToken,
21    /// Client credentials grant type.
22    #[serde(rename = "client_credentials")]
23    ClientCredentials,
24    /// Resource owner password credentials grant type.
25    #[serde(rename = "password")]
26    Password,
27}
28
29/// OAuth 2.0 response types.
30#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
31pub enum ResponseType {
32    /// Authorization code response type.
33    #[serde(rename = "code")]
34    Code,
35    /// Implicit token response type.
36    #[serde(rename = "token")]
37    Token,
38}
39
40/// OAuth 2.0 token types.
41#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
42#[serde(rename_all = "lowercase")]
43pub enum TokenType {
44    /// Bearer token type.
45    Bearer,
46}
47
48/// OAuth 2.0 client registration.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct OAuthClient {
51    /// Client identifier.
52    pub client_id: String,
53
54    /// Client secret (confidential clients only).
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub client_secret: Option<String>,
57
58    /// Client name.
59    pub client_name: String,
60
61    /// Redirect URIs.
62    pub redirect_uris: Vec<String>,
63
64    /// Allowed grant types.
65    pub grant_types: Vec<GrantType>,
66
67    /// Allowed response types.
68    pub response_types: Vec<ResponseType>,
69
70    /// Allowed scopes.
71    pub scopes: Vec<String>,
72
73    /// Client metadata.
74    #[serde(flatten)]
75    pub metadata: HashMap<String, serde_json::Value>,
76}
77
/// Server-side record of an issued authorization code.
#[derive(Clone, Debug)]
pub struct AuthorizationCode {
    /// The code value handed to the client.
    pub code: String,

    /// Client the code was issued to.
    pub client_id: String,

    /// User the code was issued for.
    pub user_id: String,

    /// Redirect URI supplied in the authorization request.
    pub redirect_uri: String,

    /// Scopes requested for this code.
    pub scopes: Vec<String>,

    /// PKCE code challenge, when the client supplied one.
    pub code_challenge: Option<String>,

    /// PKCE challenge method that accompanies `code_challenge`.
    pub code_challenge_method: Option<String>,

    /// Expiration timestamp; `InMemoryOAuthProvider` stores Unix seconds here.
    pub expires_at: u64,
}
105
106/// OAuth 2.0 access token.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct AccessToken {
109    /// Token value.
110    pub access_token: String,
111
112    /// Token type (always "bearer").
113    pub token_type: TokenType,
114
115    /// Expiration time in seconds.
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub expires_in: Option<u64>,
118
119    /// Refresh token.
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub refresh_token: Option<String>,
122
123    /// Granted scopes.
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub scope: Option<String>,
126
127    /// Additional token metadata.
128    #[serde(flatten)]
129    pub extra: HashMap<String, serde_json::Value>,
130}
131
132/// OAuth 2.0 token info for introspection.
133#[derive(Debug, Clone)]
134pub struct TokenInfo {
135    /// Token value.
136    pub token: String,
137
138    /// Client ID.
139    pub client_id: String,
140
141    /// User ID.
142    pub user_id: String,
143
144    /// Granted scopes.
145    pub scopes: Vec<String>,
146
147    /// Expiration timestamp.
148    pub expires_at: u64,
149
150    /// Token type.
151    pub token_type: TokenType,
152}
153
154/// OAuth 2.0 error response.
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct OAuthError {
157    /// Error code.
158    pub error: String,
159
160    /// Error description.
161    #[serde(skip_serializing_if = "Option::is_none")]
162    pub error_description: Option<String>,
163
164    /// Error URI.
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub error_uri: Option<String>,
167}
168
169/// `OpenID Connect Discovery` metadata.
170/// Represents the well-known configuration for OAuth 2.0/OIDC servers.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct OidcDiscoveryMetadata {
173    /// Issuer identifier.
174    pub issuer: String,
175
176    /// Authorization endpoint URL.
177    pub authorization_endpoint: String,
178
179    /// Token endpoint URL.
180    pub token_endpoint: String,
181
182    /// JWKS (JSON Web Key Set) URI.
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub jwks_uri: Option<String>,
185
186    /// User info endpoint URL.
187    #[serde(skip_serializing_if = "Option::is_none")]
188    pub userinfo_endpoint: Option<String>,
189
190    /// Registration endpoint.
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub registration_endpoint: Option<String>,
193
194    /// Revocation endpoint URL.
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub revocation_endpoint: Option<String>,
197
198    /// Introspection endpoint URL.
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub introspection_endpoint: Option<String>,
201
202    /// Supported response types.
203    pub response_types_supported: Vec<ResponseType>,
204
205    /// Supported grant types.
206    pub grant_types_supported: Vec<GrantType>,
207
208    /// Supported scopes.
209    pub scopes_supported: Vec<String>,
210
211    /// Supported token endpoint auth methods.
212    pub token_endpoint_auth_methods_supported: Vec<String>,
213
214    /// Supported PKCE code challenge methods.
215    pub code_challenge_methods_supported: Vec<String>,
216}
217
218/// OAuth 2.0 server metadata (alias for backward compatibility).
219pub type OAuthMetadata = OidcDiscoveryMetadata;
220
221/// OAuth 2.0 authorization request.
222#[derive(Debug, Clone, Deserialize)]
223pub struct AuthorizationRequest {
224    /// Response type.
225    pub response_type: ResponseType,
226
227    /// Client ID.
228    pub client_id: String,
229
230    /// Redirect URI.
231    pub redirect_uri: String,
232
233    /// Requested scope.
234    #[serde(default)]
235    pub scope: String,
236
237    /// State parameter.
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub state: Option<String>,
240
241    /// PKCE code challenge.
242    #[serde(skip_serializing_if = "Option::is_none")]
243    pub code_challenge: Option<String>,
244
245    /// PKCE challenge method.
246    #[serde(skip_serializing_if = "Option::is_none")]
247    pub code_challenge_method: Option<String>,
248}
249
250/// OAuth 2.0 token request.
251#[derive(Debug, Clone, Deserialize)]
252pub struct TokenRequest {
253    /// Grant type.
254    pub grant_type: GrantType,
255
256    /// Authorization code (for `authorization_code` grant).
257    #[serde(skip_serializing_if = "Option::is_none")]
258    pub code: Option<String>,
259
260    /// Redirect URI (for `authorization_code` grant).
261    #[serde(skip_serializing_if = "Option::is_none")]
262    pub redirect_uri: Option<String>,
263
264    /// Client ID.
265    #[serde(skip_serializing_if = "Option::is_none")]
266    pub client_id: Option<String>,
267
268    /// Client secret.
269    #[serde(skip_serializing_if = "Option::is_none")]
270    pub client_secret: Option<String>,
271
272    /// Refresh token (for `refresh_token` grant).
273    #[serde(skip_serializing_if = "Option::is_none")]
274    pub refresh_token: Option<String>,
275
276    /// Username (for password grant).
277    #[serde(skip_serializing_if = "Option::is_none")]
278    pub username: Option<String>,
279
280    /// Password (for password grant).
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub password: Option<String>,
283
284    /// Requested scope.
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub scope: Option<String>,
287
288    /// PKCE code verifier.
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub code_verifier: Option<String>,
291}
292
293/// OAuth 2.0 revocation request.
294#[derive(Debug, Clone, Deserialize)]
295pub struct RevocationRequest {
296    /// Token to revoke.
297    pub token: String,
298
299    /// Token type hint.
300    #[serde(skip_serializing_if = "Option::is_none")]
301    pub token_type_hint: Option<String>,
302
303    /// Client ID.
304    #[serde(skip_serializing_if = "Option::is_none")]
305    pub client_id: Option<String>,
306
307    /// Client secret.
308    #[serde(skip_serializing_if = "Option::is_none")]
309    pub client_secret: Option<String>,
310}
311
312/// OAuth 2.0 server provider trait.
313#[async_trait]
314pub trait OAuthProvider: Send + Sync {
315    /// Register a new client.
316    async fn register_client(&self, client: OAuthClient) -> Result<OAuthClient>;
317
318    /// Get client by ID.
319    async fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>;
320
321    /// Validate authorization request.
322    async fn validate_authorization(&self, request: &AuthorizationRequest) -> Result<()>;
323
324    /// Create authorization code.
325    async fn create_authorization_code(
326        &self,
327        client_id: &str,
328        user_id: &str,
329        redirect_uri: &str,
330        scopes: Vec<String>,
331        code_challenge: Option<String>,
332        code_challenge_method: Option<String>,
333    ) -> Result<String>;
334
335    /// Exchange authorization code for token.
336    async fn exchange_code(&self, request: &TokenRequest) -> Result<AccessToken>;
337
338    /// Create access token.
339    async fn create_access_token(
340        &self,
341        client_id: &str,
342        user_id: &str,
343        scopes: Vec<String>,
344    ) -> Result<AccessToken>;
345
346    /// Refresh access token.
347    async fn refresh_token(&self, refresh_token: &str) -> Result<AccessToken>;
348
349    /// Revoke token.
350    async fn revoke_token(&self, token: &str) -> Result<()>;
351
352    /// Validate access token.
353    async fn validate_token(&self, token: &str) -> Result<TokenInfo>;
354
355    /// Get server metadata.
356    async fn metadata(&self) -> Result<OAuthMetadata>;
357
358    /// Discover OIDC configuration from well-known endpoint.
359    /// Returns the discovery metadata if successful.
360    /// Implementations should handle retries for network failures.
361    async fn discover(&self, _issuer_url: &str) -> Result<OidcDiscoveryMetadata> {
362        // Default implementation that would fetch from .well-known/openid-configuration
363        // For now, return an error indicating it needs implementation
364        Err(Error::protocol(
365            ErrorCode::METHOD_NOT_FOUND,
366            "OIDC discovery not implemented for this provider",
367        ))
368    }
369}
370
371/// In-memory OAuth 2.0 provider implementation.
372#[derive(Debug)]
373pub struct InMemoryOAuthProvider {
374    /// Base URL for endpoints.
375    base_url: String,
376
377    /// Registered clients.
378    clients: Arc<RwLock<HashMap<String, OAuthClient>>>,
379
380    /// Active authorization codes.
381    codes: Arc<RwLock<HashMap<String, AuthorizationCode>>>,
382
383    /// Active access tokens.
384    tokens: Arc<RwLock<HashMap<String, TokenInfo>>>,
385
386    /// Refresh tokens.
387    refresh_tokens: Arc<RwLock<HashMap<String, String>>>,
388
389    /// Token expiration time in seconds.
390    token_expiration: u64,
391
392    /// Code expiration time in seconds.
393    code_expiration: u64,
394
395    /// Supported scopes.
396    supported_scopes: Vec<String>,
397}
398
399impl InMemoryOAuthProvider {
400    /// Create a new in-memory OAuth provider.
401    pub fn new(base_url: impl Into<String>) -> Self {
402        Self {
403            base_url: base_url.into(),
404            clients: Arc::new(RwLock::new(HashMap::new())),
405            codes: Arc::new(RwLock::new(HashMap::new())),
406            tokens: Arc::new(RwLock::new(HashMap::new())),
407            refresh_tokens: Arc::new(RwLock::new(HashMap::new())),
408            token_expiration: 3600, // 1 hour
409            code_expiration: 600,   // 10 minutes
410            supported_scopes: vec!["read".to_string(), "write".to_string()],
411        }
412    }
413
414    /// Generate a secure random token.
415    fn generate_token() -> String {
416        Uuid::new_v4().to_string()
417    }
418
419    /// Get current timestamp.
420    fn now() -> u64 {
421        std::time::SystemTime::now()
422            .duration_since(std::time::UNIX_EPOCH)
423            .unwrap()
424            .as_secs()
425    }
426
427    /// Verify PKCE code challenge.
428    fn verify_pkce(verifier: &str, challenge: &str, method: &str) -> bool {
429        match method {
430            "plain" => verifier == challenge,
431            "S256" => {
432                use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
433                use sha2::{Digest, Sha256};
434
435                let mut hasher = Sha256::new();
436                hasher.update(verifier.as_bytes());
437                let result = hasher.finalize();
438                let encoded = URL_SAFE_NO_PAD.encode(result);
439                encoded == challenge
440            },
441            _ => false,
442        }
443    }
444}
445
446#[async_trait]
447impl OAuthProvider for InMemoryOAuthProvider {
448    async fn register_client(&self, mut client: OAuthClient) -> Result<OAuthClient> {
449        // Generate client credentials if not provided
450        if client.client_id.is_empty() {
451            client.client_id = Self::generate_token();
452        }
453        if client.client_secret.is_none() {
454            client.client_secret = Some(Self::generate_token());
455        }
456
457        // Store client
458        let mut clients = self.clients.write().await;
459        clients.insert(client.client_id.clone(), client.clone());
460
461        Ok(client)
462    }
463
464    async fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>> {
465        let clients = self.clients.read().await;
466        Ok(clients.get(client_id).cloned())
467    }
468
469    async fn validate_authorization(&self, request: &AuthorizationRequest) -> Result<()> {
470        // Get client
471        let client = self
472            .get_client(&request.client_id)
473            .await?
474            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid client_id"))?;
475
476        // Validate redirect URI
477        if !client.redirect_uris.contains(&request.redirect_uri) {
478            return Err(Error::protocol(
479                ErrorCode::INVALID_REQUEST,
480                "Invalid redirect_uri",
481            ));
482        }
483
484        // Validate response type
485        if !client.response_types.contains(&request.response_type) {
486            return Err(Error::protocol(
487                ErrorCode::INVALID_REQUEST,
488                "Unsupported response_type",
489            ));
490        }
491
492        // Validate scopes
493        let requested_scopes: Vec<&str> = request.scope.split_whitespace().collect();
494        for scope in &requested_scopes {
495            if !self.supported_scopes.iter().any(|s| s == scope) {
496                return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid scope"));
497            }
498        }
499
500        Ok(())
501    }
502
503    async fn create_authorization_code(
504        &self,
505        client_id: &str,
506        user_id: &str,
507        redirect_uri: &str,
508        scopes: Vec<String>,
509        code_challenge: Option<String>,
510        code_challenge_method: Option<String>,
511    ) -> Result<String> {
512        let code = Self::generate_token();
513        let expires_at = Self::now() + self.code_expiration;
514
515        let auth_code = AuthorizationCode {
516            code: code.clone(),
517            client_id: client_id.to_string(),
518            user_id: user_id.to_string(),
519            redirect_uri: redirect_uri.to_string(),
520            scopes,
521            code_challenge,
522            code_challenge_method,
523            expires_at,
524        };
525
526        let mut codes = self.codes.write().await;
527        codes.insert(code.clone(), auth_code);
528
529        Ok(code)
530    }
531
532    async fn exchange_code(&self, request: &TokenRequest) -> Result<AccessToken> {
533        // Validate grant type
534        if request.grant_type != GrantType::AuthorizationCode {
535            return Err(Error::protocol(
536                ErrorCode::INVALID_REQUEST,
537                "Invalid grant_type",
538            ));
539        }
540
541        // Get code
542        let code = request
543            .code
544            .as_ref()
545            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing code"))?;
546
547        let mut codes = self.codes.write().await;
548        let auth_code = codes
549            .remove(code)
550            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid code"))?;
551
552        // Check expiration
553        if auth_code.expires_at < Self::now() {
554            return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Code expired"));
555        }
556
557        // Validate client
558        let client_id = request
559            .client_id
560            .as_ref()
561            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing client_id"))?;
562
563        if auth_code.client_id != *client_id {
564            return Err(Error::protocol(
565                ErrorCode::INVALID_REQUEST,
566                "Invalid client_id",
567            ));
568        }
569
570        // Validate redirect URI
571        let redirect_uri = request
572            .redirect_uri
573            .as_ref()
574            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Missing redirect_uri"))?;
575
576        if auth_code.redirect_uri != *redirect_uri {
577            return Err(Error::protocol(
578                ErrorCode::INVALID_REQUEST,
579                "Invalid redirect_uri",
580            ));
581        }
582
583        // Verify PKCE if used
584        if let Some(challenge) = &auth_code.code_challenge {
585            let verifier = request.code_verifier.as_ref().ok_or_else(|| {
586                Error::protocol(ErrorCode::INVALID_REQUEST, "Missing code_verifier")
587            })?;
588
589            let method = auth_code
590                .code_challenge_method
591                .as_deref()
592                .unwrap_or("plain");
593            if !Self::verify_pkce(verifier, challenge, method) {
594                return Err(Error::protocol(
595                    ErrorCode::INVALID_REQUEST,
596                    "Invalid code_verifier",
597                ));
598            }
599        }
600
601        // Create access token
602        self.create_access_token(&auth_code.client_id, &auth_code.user_id, auth_code.scopes)
603            .await
604    }
605
606    async fn create_access_token(
607        &self,
608        client_id: &str,
609        user_id: &str,
610        scopes: Vec<String>,
611    ) -> Result<AccessToken> {
612        let access_token = Self::generate_token();
613        let refresh_token = Self::generate_token();
614        let expires_at = Self::now() + self.token_expiration;
615
616        // Store token info
617        let token_info = TokenInfo {
618            token: access_token.clone(),
619            client_id: client_id.to_string(),
620            user_id: user_id.to_string(),
621            scopes: scopes.clone(),
622            expires_at,
623            token_type: TokenType::Bearer,
624        };
625
626        let mut tokens = self.tokens.write().await;
627        tokens.insert(access_token.clone(), token_info);
628
629        // Store refresh token mapping
630        let mut refresh_tokens = self.refresh_tokens.write().await;
631        refresh_tokens.insert(refresh_token.clone(), access_token.clone());
632
633        Ok(AccessToken {
634            access_token,
635            token_type: TokenType::Bearer,
636            expires_in: Some(self.token_expiration),
637            refresh_token: Some(refresh_token),
638            scope: Some(scopes.join(" ")),
639            extra: HashMap::new(),
640        })
641    }
642
643    async fn refresh_token(&self, refresh_token: &str) -> Result<AccessToken> {
644        // Get associated access token
645        let refresh_tokens = self.refresh_tokens.read().await;
646        let old_token = refresh_tokens
647            .get(refresh_token)
648            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid refresh_token"))?
649            .clone();
650
651        // Get token info
652        let tokens = self.tokens.read().await;
653        let token_info = tokens
654            .get(&old_token)
655            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid refresh_token"))?;
656
657        let client_id = token_info.client_id.clone();
658        let user_id = token_info.user_id.clone();
659        let scopes = token_info.scopes.clone();
660
661        drop(tokens);
662        drop(refresh_tokens);
663
664        // Remove old tokens
665        let mut tokens = self.tokens.write().await;
666        tokens.remove(&old_token);
667        drop(tokens);
668
669        let mut refresh_tokens = self.refresh_tokens.write().await;
670        refresh_tokens.remove(refresh_token);
671        drop(refresh_tokens);
672
673        // Create new token
674        self.create_access_token(&client_id, &user_id, scopes).await
675    }
676
677    async fn revoke_token(&self, token: &str) -> Result<()> {
678        // Try to revoke as access token
679        let mut tokens = self.tokens.write().await;
680        if tokens.remove(token).is_some() {
681            return Ok(());
682        }
683        drop(tokens);
684
685        // Try to revoke as refresh token
686        let mut refresh_tokens = self.refresh_tokens.write().await;
687        if let Some(access_token) = refresh_tokens.remove(token) {
688            let mut tokens = self.tokens.write().await;
689            tokens.remove(&access_token);
690        }
691
692        Ok(())
693    }
694
695    async fn validate_token(&self, token: &str) -> Result<TokenInfo> {
696        let tokens = self.tokens.read().await;
697        let token_info = tokens
698            .get(token)
699            .ok_or_else(|| Error::protocol(ErrorCode::INVALID_REQUEST, "Invalid token"))?;
700
701        // Check expiration
702        if token_info.expires_at < Self::now() {
703            return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Token expired"));
704        }
705
706        Ok(token_info.clone())
707    }
708
709    async fn metadata(&self) -> Result<OAuthMetadata> {
710        Ok(OAuthMetadata {
711            issuer: self.base_url.clone(),
712            authorization_endpoint: format!("{}/oauth2/authorize", self.base_url),
713            token_endpoint: format!("{}/oauth2/token", self.base_url),
714            jwks_uri: Some(format!("{}/oauth2/jwks", self.base_url)),
715            userinfo_endpoint: Some(format!("{}/oauth2/userinfo", self.base_url)),
716            registration_endpoint: Some(format!("{}/oauth2/register", self.base_url)),
717            revocation_endpoint: Some(format!("{}/oauth2/revoke", self.base_url)),
718            introspection_endpoint: Some(format!("{}/oauth2/introspect", self.base_url)),
719            response_types_supported: vec![ResponseType::Code],
720            grant_types_supported: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
721            scopes_supported: self.supported_scopes.clone(),
722            token_endpoint_auth_methods_supported: vec![
723                "client_secret_basic".to_string(),
724                "client_secret_post".to_string(),
725            ],
726            code_challenge_methods_supported: vec!["plain".to_string(), "S256".to_string()],
727        })
728    }
729}
730
731/// Proxy OAuth provider that delegates to an upstream OAuth server.
732#[derive(Debug)]
733pub struct ProxyOAuthProvider {
734    /// Upstream OAuth server URL.
735    _upstream_url: String,
736
737    /// Local token cache.
738    _token_cache: Arc<RwLock<HashMap<String, TokenInfo>>>,
739}
740
741impl ProxyOAuthProvider {
742    /// Create a new proxy OAuth provider.
743    pub fn new(upstream_url: impl Into<String>) -> Self {
744        Self {
745            _upstream_url: upstream_url.into(),
746            _token_cache: Arc::new(RwLock::new(HashMap::new())),
747        }
748    }
749}
750
// Note: `ProxyOAuthProvider` does not implement `OAuthProvider` here. Doing so requires
// HTTP client functionality, which is beyond the scope of this initial implementation.
753
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: register a client, validate an authorization request, issue
    /// a code, exchange it for a token, then validate that token.
    #[tokio::test]
    async fn test_oauth_flow() {
        let callback = "http://localhost:3000/callback".to_string();
        let scopes = vec!["read".to_string(), "write".to_string()];
        let provider = InMemoryOAuthProvider::new("http://localhost:8080");

        // Register a client without credentials; the provider must fill them in.
        let registered = provider
            .register_client(OAuthClient {
                client_id: String::new(),
                client_secret: None,
                client_name: "Test Client".to_string(),
                redirect_uris: vec![callback.clone()],
                grant_types: vec![GrantType::AuthorizationCode],
                response_types: vec![ResponseType::Code],
                scopes: scopes.clone(),
                metadata: HashMap::new(),
            })
            .await
            .unwrap();
        assert!(!registered.client_id.is_empty());
        assert!(registered.client_secret.is_some());

        // The authorization request must be accepted.
        let auth_req = AuthorizationRequest {
            response_type: ResponseType::Code,
            client_id: registered.client_id.clone(),
            redirect_uri: callback.clone(),
            scope: "read write".to_string(),
            state: Some("test-state".to_string()),
            code_challenge: None,
            code_challenge_method: None,
        };
        provider.validate_authorization(&auth_req).await.unwrap();

        // Issue an authorization code without PKCE.
        let code = provider
            .create_authorization_code(
                &registered.client_id,
                "user-123",
                &auth_req.redirect_uri,
                scopes,
                None,
                None,
            )
            .await
            .unwrap();

        // Exchange the code for a token.
        let token_req = TokenRequest {
            grant_type: GrantType::AuthorizationCode,
            code: Some(code),
            redirect_uri: Some(auth_req.redirect_uri),
            client_id: Some(registered.client_id.clone()),
            client_secret: registered.client_secret.clone(),
            refresh_token: None,
            username: None,
            password: None,
            scope: None,
            code_verifier: None,
        };
        let token = provider.exchange_code(&token_req).await.unwrap();
        assert_eq!(token.token_type, TokenType::Bearer);
        assert!(token.refresh_token.is_some());

        // The issued token must validate and carry the right identities.
        let token_info = provider.validate_token(&token.access_token).await.unwrap();
        assert_eq!(token_info.client_id, registered.client_id);
        assert_eq!(token_info.user_id, "user-123");
    }
}