ipfrs_interface/
oauth2.rs

//! OAuth2 Authentication Module
//!
//! Implements OAuth2 authentication flows including:
//! - Authorization Code Flow (with PKCE)
//! - Client Credentials Flow
//! - Refresh Token Flow
//!
//! Supports integration with external OAuth2 providers (Google, GitHub, etc.)
//! and can also act as an OAuth2 authorization server.
10
11use crate::auth::{AuthError, AuthResult, JwtManager};
12use dashmap::DashMap;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16use uuid::Uuid;
17
18/// OAuth2 grant types
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum GrantType {
22    /// Authorization Code Flow
23    AuthorizationCode,
24    /// Client Credentials Flow
25    ClientCredentials,
26    /// Refresh Token Flow
27    RefreshToken,
28    /// Implicit Flow (deprecated, but included for compatibility)
29    Implicit,
30}
31
32/// OAuth2 token type
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
34#[serde(rename_all = "snake_case")]
35pub enum TokenType {
36    Bearer,
37}
38
39/// OAuth2 response type
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
41#[serde(rename_all = "snake_case")]
42pub enum ResponseType {
43    Code,
44    Token,
45}
46
47/// OAuth2 scope
48#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
49pub struct Scope(String);
50
51impl Scope {
52    pub fn new(scope: impl Into<String>) -> Self {
53        Self(scope.into())
54    }
55
56    pub fn as_str(&self) -> &str {
57        &self.0
58    }
59
60    /// Parse space-separated scopes
61    pub fn parse_scopes(scopes: &str) -> Vec<Scope> {
62        scopes.split_whitespace().map(Scope::new).collect()
63    }
64
65    /// Join scopes into space-separated string
66    pub fn join_scopes(scopes: &[Scope]) -> String {
67        scopes
68            .iter()
69            .map(|s| s.as_str())
70            .collect::<Vec<_>>()
71            .join(" ")
72    }
73}
74
75/// OAuth2 client registration
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct OAuth2Client {
78    pub client_id: String,
79    pub client_secret: String,
80    pub redirect_uris: Vec<String>,
81    pub grant_types: Vec<GrantType>,
82    pub scopes: Vec<Scope>,
83    pub name: String,
84    pub created_at: u64,
85}
86
87impl OAuth2Client {
88    pub fn new(
89        name: String,
90        redirect_uris: Vec<String>,
91        grant_types: Vec<GrantType>,
92        scopes: Vec<Scope>,
93    ) -> Self {
94        Self {
95            client_id: Uuid::new_v4().to_string(),
96            client_secret: Uuid::new_v4().to_string(),
97            redirect_uris,
98            grant_types,
99            scopes,
100            name,
101            created_at: SystemTime::now()
102                .duration_since(UNIX_EPOCH)
103                .unwrap()
104                .as_secs(),
105        }
106    }
107
108    /// Verify client secret
109    pub fn verify_secret(&self, secret: &str) -> bool {
110        self.client_secret == secret
111    }
112
113    /// Check if redirect URI is allowed
114    pub fn is_redirect_uri_allowed(&self, uri: &str) -> bool {
115        self.redirect_uris.iter().any(|u| u == uri)
116    }
117
118    /// Check if grant type is allowed
119    pub fn is_grant_type_allowed(&self, grant_type: GrantType) -> bool {
120        self.grant_types.contains(&grant_type)
121    }
122
123    /// Check if scope is allowed
124    pub fn is_scope_allowed(&self, scope: &Scope) -> bool {
125        self.scopes.contains(scope)
126    }
127}
128
129/// Authorization code
130#[derive(Debug, Clone)]
131pub struct AuthorizationCode {
132    pub code: String,
133    pub client_id: String,
134    pub redirect_uri: String,
135    pub scopes: Vec<Scope>,
136    pub user_id: String,
137    pub expires_at: u64,
138    /// PKCE code challenge
139    pub code_challenge: Option<String>,
140    /// PKCE code challenge method
141    pub code_challenge_method: Option<CodeChallengeMethod>,
142}
143
144impl AuthorizationCode {
145    pub fn new(
146        client_id: String,
147        redirect_uri: String,
148        scopes: Vec<Scope>,
149        user_id: String,
150        ttl: Duration,
151        code_challenge: Option<String>,
152        code_challenge_method: Option<CodeChallengeMethod>,
153    ) -> Self {
154        let expires_at = SystemTime::now()
155            .duration_since(UNIX_EPOCH)
156            .unwrap()
157            .as_secs()
158            + ttl.as_secs();
159
160        Self {
161            code: Uuid::new_v4().to_string(),
162            client_id,
163            redirect_uri,
164            scopes,
165            user_id,
166            expires_at,
167            code_challenge,
168            code_challenge_method,
169        }
170    }
171
172    pub fn is_expired(&self) -> bool {
173        let now = SystemTime::now()
174            .duration_since(UNIX_EPOCH)
175            .unwrap()
176            .as_secs();
177        now > self.expires_at
178    }
179
180    /// Verify PKCE code verifier
181    pub fn verify_code_verifier(&self, verifier: &str) -> bool {
182        match (&self.code_challenge, &self.code_challenge_method) {
183            (Some(challenge), Some(method)) => {
184                let computed_challenge = method.compute_challenge(verifier);
185                &computed_challenge == challenge
186            }
187            (None, None) => true, // No PKCE required
188            _ => false,           // Inconsistent PKCE state
189        }
190    }
191}
192
193/// PKCE code challenge method
194#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
195pub enum CodeChallengeMethod {
196    #[serde(rename = "plain")]
197    Plain,
198    #[serde(rename = "S256")]
199    S256,
200}
201
202impl CodeChallengeMethod {
203    /// Compute challenge from verifier
204    pub fn compute_challenge(&self, verifier: &str) -> String {
205        match self {
206            Self::Plain => verifier.to_string(),
207            Self::S256 => {
208                use sha2::{Digest, Sha256};
209                let hash = Sha256::digest(verifier.as_bytes());
210                base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, hash)
211            }
212        }
213    }
214}
215
216/// OAuth2 access token
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct AccessToken {
219    pub token: String,
220    pub token_type: TokenType,
221    pub expires_in: u64,
222    pub scopes: Vec<Scope>,
223    pub user_id: String,
224    pub created_at: u64,
225}
226
227impl AccessToken {
228    pub fn new(token: String, scopes: Vec<Scope>, user_id: String, ttl: Duration) -> Self {
229        Self {
230            token,
231            token_type: TokenType::Bearer,
232            expires_in: ttl.as_secs(),
233            scopes,
234            user_id,
235            created_at: SystemTime::now()
236                .duration_since(UNIX_EPOCH)
237                .unwrap()
238                .as_secs(),
239        }
240    }
241
242    pub fn is_expired(&self) -> bool {
243        let now = SystemTime::now()
244            .duration_since(UNIX_EPOCH)
245            .unwrap()
246            .as_secs();
247        now > self.created_at + self.expires_in
248    }
249}
250
251/// OAuth2 refresh token
252#[derive(Debug, Clone)]
253pub struct RefreshToken {
254    pub token: String,
255    pub client_id: String,
256    pub user_id: String,
257    pub scopes: Vec<Scope>,
258    pub created_at: u64,
259}
260
261impl RefreshToken {
262    pub fn new(client_id: String, user_id: String, scopes: Vec<Scope>) -> Self {
263        Self {
264            token: Uuid::new_v4().to_string(),
265            client_id,
266            user_id,
267            scopes,
268            created_at: SystemTime::now()
269                .duration_since(UNIX_EPOCH)
270                .unwrap()
271                .as_secs(),
272        }
273    }
274}
275
276/// OAuth2 token response
277#[derive(Debug, Serialize, Deserialize)]
278pub struct TokenResponse {
279    pub access_token: String,
280    pub token_type: String,
281    pub expires_in: u64,
282    #[serde(skip_serializing_if = "Option::is_none")]
283    pub refresh_token: Option<String>,
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub scope: Option<String>,
286}
287
288/// OAuth2 error response
289#[derive(Debug, Serialize, Deserialize)]
290pub struct ErrorResponse {
291    pub error: String,
292    #[serde(skip_serializing_if = "Option::is_none")]
293    pub error_description: Option<String>,
294}
295
296/// OAuth2 authorization server
297pub struct OAuth2Server {
298    clients: Arc<DashMap<String, OAuth2Client>>,
299    authorization_codes: Arc<DashMap<String, AuthorizationCode>>,
300    access_tokens: Arc<DashMap<String, AccessToken>>,
301    refresh_tokens: Arc<DashMap<String, RefreshToken>>,
302    jwt_manager: Arc<JwtManager>,
303    /// Default access token TTL
304    access_token_ttl: Duration,
305    /// Default refresh token TTL
306    #[allow(dead_code)]
307    refresh_token_ttl: Duration,
308    /// Default authorization code TTL
309    code_ttl: Duration,
310}
311
312impl OAuth2Server {
313    pub fn new(jwt_secret: &[u8]) -> Self {
314        Self {
315            clients: Arc::new(DashMap::new()),
316            authorization_codes: Arc::new(DashMap::new()),
317            access_tokens: Arc::new(DashMap::new()),
318            refresh_tokens: Arc::new(DashMap::new()),
319            jwt_manager: Arc::new(JwtManager::new(jwt_secret)),
320            access_token_ttl: Duration::from_secs(3600), // 1 hour
321            refresh_token_ttl: Duration::from_secs(86400 * 30), // 30 days
322            code_ttl: Duration::from_secs(600),          // 10 minutes
323        }
324    }
325
326    /// Register a new OAuth2 client
327    pub fn register_client(
328        &self,
329        name: String,
330        redirect_uris: Vec<String>,
331        grant_types: Vec<GrantType>,
332        scopes: Vec<Scope>,
333    ) -> OAuth2Client {
334        let client = OAuth2Client::new(name, redirect_uris, grant_types, scopes);
335        self.clients
336            .insert(client.client_id.clone(), client.clone());
337        client
338    }
339
340    /// Get client by ID
341    pub fn get_client(&self, client_id: &str) -> Option<OAuth2Client> {
342        self.clients.get(client_id).map(|c| c.clone())
343    }
344
345    /// Authorize request (Authorization Code Flow)
346    #[allow(clippy::too_many_arguments)]
347    pub fn authorize(
348        &self,
349        client_id: &str,
350        redirect_uri: &str,
351        response_type: ResponseType,
352        scopes: Vec<Scope>,
353        user_id: String,
354        code_challenge: Option<String>,
355        code_challenge_method: Option<CodeChallengeMethod>,
356    ) -> AuthResult<AuthorizationCode> {
357        // Validate client
358        let client = self
359            .get_client(client_id)
360            .ok_or(AuthError::InvalidCredentials)?;
361
362        // Validate redirect URI
363        if !client.is_redirect_uri_allowed(redirect_uri) {
364            return Err(AuthError::InvalidCredentials);
365        }
366
367        // Validate grant type
368        if !client.is_grant_type_allowed(GrantType::AuthorizationCode) {
369            return Err(AuthError::InvalidCredentials);
370        }
371
372        // Validate scopes
373        for scope in &scopes {
374            if !client.is_scope_allowed(scope) {
375                return Err(AuthError::InsufficientPermissions);
376            }
377        }
378
379        // Validate response type
380        if response_type != ResponseType::Code {
381            return Err(AuthError::InvalidCredentials);
382        }
383
384        // Create authorization code
385        let auth_code = AuthorizationCode::new(
386            client_id.to_string(),
387            redirect_uri.to_string(),
388            scopes,
389            user_id,
390            self.code_ttl,
391            code_challenge,
392            code_challenge_method,
393        );
394
395        self.authorization_codes
396            .insert(auth_code.code.clone(), auth_code.clone());
397
398        Ok(auth_code)
399    }
400
401    /// Exchange authorization code for tokens
402    pub fn exchange_code(
403        &self,
404        client_id: &str,
405        client_secret: &str,
406        code: &str,
407        redirect_uri: &str,
408        code_verifier: Option<&str>,
409    ) -> AuthResult<(AccessToken, RefreshToken)> {
410        // Validate client
411        let client = self
412            .get_client(client_id)
413            .ok_or(AuthError::InvalidCredentials)?;
414
415        if !client.verify_secret(client_secret) {
416            return Err(AuthError::InvalidCredentials);
417        }
418
419        // Get and remove authorization code (one-time use)
420        let auth_code = self
421            .authorization_codes
422            .remove(code)
423            .ok_or(AuthError::InvalidToken("Invalid code".to_string()))?
424            .1;
425
426        // Validate code
427        if auth_code.is_expired() {
428            return Err(AuthError::TokenExpired);
429        }
430
431        if auth_code.client_id != client_id {
432            return Err(AuthError::InvalidCredentials);
433        }
434
435        if auth_code.redirect_uri != redirect_uri {
436            return Err(AuthError::InvalidCredentials);
437        }
438
439        // Verify PKCE code verifier if required
440        if let Some(verifier) = code_verifier {
441            if !auth_code.verify_code_verifier(verifier) {
442                return Err(AuthError::InvalidCredentials);
443            }
444        } else if auth_code.code_challenge.is_some() {
445            // Code challenge was provided but no verifier
446            return Err(AuthError::InvalidCredentials);
447        }
448
449        // Generate access token using JWT
450        let access_token_jwt = self
451            .jwt_manager
452            .generate_token_with_scopes(
453                &auth_code.user_id,
454                &Scope::join_scopes(&auth_code.scopes),
455                (self.access_token_ttl.as_secs() / 3600) as usize,
456            )
457            .map_err(|_| AuthError::InvalidToken("Failed to generate token".to_string()))?;
458
459        let access_token = AccessToken::new(
460            access_token_jwt,
461            auth_code.scopes.clone(),
462            auth_code.user_id.clone(),
463            self.access_token_ttl,
464        );
465
466        // Generate refresh token
467        let refresh_token = RefreshToken::new(
468            client_id.to_string(),
469            auth_code.user_id.clone(),
470            auth_code.scopes.clone(),
471        );
472
473        // Store tokens
474        self.access_tokens
475            .insert(access_token.token.clone(), access_token.clone());
476        self.refresh_tokens
477            .insert(refresh_token.token.clone(), refresh_token.clone());
478
479        Ok((access_token, refresh_token))
480    }
481
482    /// Client Credentials Flow
483    pub fn client_credentials(
484        &self,
485        client_id: &str,
486        client_secret: &str,
487        scopes: Vec<Scope>,
488    ) -> AuthResult<AccessToken> {
489        // Validate client
490        let client = self
491            .get_client(client_id)
492            .ok_or(AuthError::InvalidCredentials)?;
493
494        if !client.verify_secret(client_secret) {
495            return Err(AuthError::InvalidCredentials);
496        }
497
498        // Validate grant type
499        if !client.is_grant_type_allowed(GrantType::ClientCredentials) {
500            return Err(AuthError::InvalidCredentials);
501        }
502
503        // Validate scopes
504        for scope in &scopes {
505            if !client.is_scope_allowed(scope) {
506                return Err(AuthError::InsufficientPermissions);
507            }
508        }
509
510        // Generate access token (use client_id as user_id for client credentials)
511        let access_token_jwt = self
512            .jwt_manager
513            .generate_token_with_scopes(
514                client_id,
515                &Scope::join_scopes(&scopes),
516                (self.access_token_ttl.as_secs() / 3600) as usize,
517            )
518            .map_err(|_| AuthError::InvalidToken("Failed to generate token".to_string()))?;
519
520        let access_token = AccessToken::new(
521            access_token_jwt,
522            scopes,
523            client_id.to_string(),
524            self.access_token_ttl,
525        );
526
527        self.access_tokens
528            .insert(access_token.token.clone(), access_token.clone());
529
530        Ok(access_token)
531    }
532
533    /// Refresh access token
534    pub fn refresh_token(
535        &self,
536        client_id: &str,
537        client_secret: &str,
538        refresh_token: &str,
539    ) -> AuthResult<AccessToken> {
540        // Validate client
541        let client = self
542            .get_client(client_id)
543            .ok_or(AuthError::InvalidCredentials)?;
544
545        if !client.verify_secret(client_secret) {
546            return Err(AuthError::InvalidCredentials);
547        }
548
549        // Get refresh token
550        let rt = self
551            .refresh_tokens
552            .get(refresh_token)
553            .ok_or(AuthError::InvalidToken("Invalid refresh token".to_string()))?;
554
555        if rt.client_id != client_id {
556            return Err(AuthError::InvalidCredentials);
557        }
558
559        // Generate new access token
560        let access_token_jwt = self
561            .jwt_manager
562            .generate_token_with_scopes(
563                &rt.user_id,
564                &Scope::join_scopes(&rt.scopes),
565                (self.access_token_ttl.as_secs() / 3600) as usize,
566            )
567            .map_err(|_| AuthError::InvalidToken("Failed to generate token".to_string()))?;
568
569        let access_token = AccessToken::new(
570            access_token_jwt,
571            rt.scopes.clone(),
572            rt.user_id.clone(),
573            self.access_token_ttl,
574        );
575
576        self.access_tokens
577            .insert(access_token.token.clone(), access_token.clone());
578
579        Ok(access_token)
580    }
581
582    /// Validate access token
583    pub fn validate_token(&self, token: &str) -> AuthResult<AccessToken> {
584        let access_token = self
585            .access_tokens
586            .get(token)
587            .ok_or(AuthError::InvalidToken("Token not found".to_string()))?;
588
589        if access_token.is_expired() {
590            // Remove expired token
591            drop(access_token);
592            self.access_tokens.remove(token);
593            return Err(AuthError::TokenExpired);
594        }
595
596        Ok(access_token.clone())
597    }
598
599    /// Revoke access token
600    pub fn revoke_access_token(&self, token: &str) -> bool {
601        self.access_tokens.remove(token).is_some()
602    }
603
604    /// Revoke refresh token
605    pub fn revoke_refresh_token(&self, token: &str) -> bool {
606        self.refresh_tokens.remove(token).is_some()
607    }
608
609    /// Clean up expired tokens and codes
610    pub fn cleanup_expired(&self) {
611        // Clean up expired authorization codes
612        self.authorization_codes
613            .retain(|_, code| !code.is_expired());
614
615        // Clean up expired access tokens
616        self.access_tokens.retain(|_, token| !token.is_expired());
617    }
618}
619
620impl Default for OAuth2Server {
621    fn default() -> Self {
622        Self::new(b"default-secret-change-in-production")
623    }
624}
625
626/// OAuth2 provider configuration for external providers
627#[derive(Debug, Clone, Serialize, Deserialize)]
628pub struct OAuth2ProviderConfig {
629    pub name: String,
630    pub client_id: String,
631    pub client_secret: String,
632    pub authorization_endpoint: String,
633    pub token_endpoint: String,
634    pub redirect_uri: String,
635    pub scopes: Vec<Scope>,
636}
637
638impl OAuth2ProviderConfig {
639    /// Create configuration for Google OAuth2
640    pub fn google(client_id: String, client_secret: String, redirect_uri: String) -> Self {
641        Self {
642            name: "google".to_string(),
643            client_id,
644            client_secret,
645            authorization_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
646            token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
647            redirect_uri,
648            scopes: vec![
649                Scope::new("openid"),
650                Scope::new("email"),
651                Scope::new("profile"),
652            ],
653        }
654    }
655
656    /// Create configuration for GitHub OAuth2
657    pub fn github(client_id: String, client_secret: String, redirect_uri: String) -> Self {
658        Self {
659            name: "github".to_string(),
660            client_id,
661            client_secret,
662            authorization_endpoint: "https://github.com/login/oauth/authorize".to_string(),
663            token_endpoint: "https://github.com/login/oauth/access_token".to_string(),
664            redirect_uri,
665            scopes: vec![Scope::new("user:email"), Scope::new("read:user")],
666        }
667    }
668
669    /// Build authorization URL
670    pub fn build_auth_url(&self, state: &str) -> String {
671        let scope = Scope::join_scopes(&self.scopes);
672        format!(
673            "{}?client_id={}&redirect_uri={}&scope={}&response_type=code&state={}",
674            self.authorization_endpoint,
675            urlencoding::encode(&self.client_id),
676            urlencoding::encode(&self.redirect_uri),
677            urlencoding::encode(&scope),
678            state
679        )
680    }
681}
682
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Redirect URI shared by most fixtures.
    const CALLBACK: &str = "http://localhost/callback";

    /// The single `read` scope used throughout these tests.
    fn read_scope() -> Vec<Scope> {
        vec![Scope::new("read")]
    }

    /// Registers a `read`-scoped client with one grant type on `server`.
    fn register(
        server: &OAuth2Server,
        name: &str,
        redirect_uris: Vec<String>,
        grant: GrantType,
    ) -> OAuth2Client {
        server.register_client(name.to_string(), redirect_uris, vec![grant], read_scope())
    }

    /// Registers a client-credentials client and obtains a token for it.
    fn issue_client_credentials_token(server: &OAuth2Server) -> AccessToken {
        let client = register(server, "test", vec![], GrantType::ClientCredentials);
        server
            .client_credentials(&client.client_id, &client.client_secret, read_scope())
            .unwrap()
    }

    #[test]
    fn test_scope_parsing() {
        let scopes = Scope::parse_scopes("read write admin");
        assert_eq!(scopes.len(), 3);
        assert_eq!(scopes[0].as_str(), "read");
        assert_eq!(scopes[1].as_str(), "write");
        assert_eq!(scopes[2].as_str(), "admin");
    }

    #[test]
    fn test_scope_joining() {
        let joined = Scope::join_scopes(&[Scope::new("read"), Scope::new("write")]);
        assert_eq!(joined, "read write");
    }

    #[test]
    fn test_client_creation() {
        let client = OAuth2Client::new(
            "test-client".to_string(),
            vec!["http://localhost:3000/callback".to_string()],
            vec![GrantType::AuthorizationCode],
            read_scope(),
        );

        assert!(!client.client_id.is_empty());
        assert!(!client.client_secret.is_empty());
        assert_eq!(client.name, "test-client");
    }

    #[test]
    fn test_client_verification() {
        let client = OAuth2Client::new(
            "test".to_string(),
            vec![CALLBACK.to_string()],
            vec![GrantType::AuthorizationCode],
            read_scope(),
        );

        assert!(client.verify_secret(&client.client_secret));
        assert!(!client.verify_secret("wrong-secret"));
        assert!(client.is_redirect_uri_allowed("http://localhost/callback"));
        assert!(!client.is_redirect_uri_allowed("http://evil.com/callback"));
    }

    #[test]
    fn test_pkce_plain() {
        let verifier = "test-verifier";
        let challenge = CodeChallengeMethod::Plain.compute_challenge(verifier);
        assert_eq!(challenge, verifier);
    }

    #[test]
    fn test_pkce_s256() {
        let verifier = "test-verifier-with-sufficient-entropy";
        let first = CodeChallengeMethod::S256.compute_challenge(verifier);
        assert_ne!(first, verifier);
        assert!(!first.is_empty());

        // The same verifier must always produce the same challenge.
        let second = CodeChallengeMethod::S256.compute_challenge(verifier);
        assert_eq!(first, second);
    }

    #[test]
    fn test_authorization_code_expiry() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        // Build a code whose expiry is one second in the past.
        let code = AuthorizationCode {
            code: "test-code".to_string(),
            client_id: "client-id".to_string(),
            redirect_uri: CALLBACK.to_string(),
            scopes: read_scope(),
            user_id: "user-id".to_string(),
            expires_at: now - 1,
            code_challenge: None,
            code_challenge_method: None,
        };

        assert!(code.is_expired());
    }

    #[test]
    fn test_oauth2_server_client_registration() {
        let server = OAuth2Server::default();
        let client = register(
            &server,
            "test-client",
            vec![CALLBACK.to_string()],
            GrantType::AuthorizationCode,
        );

        let retrieved = server.get_client(&client.client_id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "test-client");
    }

    #[test]
    fn test_oauth2_server_authorization() {
        let server = OAuth2Server::default();
        let client = register(
            &server,
            "test",
            vec![CALLBACK.to_string()],
            GrantType::AuthorizationCode,
        );

        let auth_code = server
            .authorize(
                &client.client_id,
                CALLBACK,
                ResponseType::Code,
                read_scope(),
                "user-123".to_string(),
                None,
                None,
            )
            .unwrap();

        assert!(!auth_code.code.is_empty());
        assert_eq!(auth_code.user_id, "user-123");
    }

    #[test]
    fn test_oauth2_server_client_credentials() {
        let server = OAuth2Server::default();
        let token = issue_client_credentials_token(&server);

        assert!(!token.token.is_empty());
        assert_eq!(token.token_type, TokenType::Bearer);
    }

    #[test]
    fn test_provider_config_google() {
        let config = OAuth2ProviderConfig::google(
            "client-id".to_string(),
            "client-secret".to_string(),
            CALLBACK.to_string(),
        );

        assert_eq!(config.name, "google");
        assert!(config.authorization_endpoint.contains("google"));

        let url = config.build_auth_url("random-state");
        assert!(url.contains("client_id=client-id"));
        assert!(url.contains("state=random-state"));
    }

    #[test]
    fn test_provider_config_github() {
        let config = OAuth2ProviderConfig::github(
            "client-id".to_string(),
            "client-secret".to_string(),
            CALLBACK.to_string(),
        );

        assert_eq!(config.name, "github");
        assert!(config.authorization_endpoint.contains("github"));
    }

    #[test]
    fn test_token_validation() {
        let server = OAuth2Server::default();
        let token = issue_client_credentials_token(&server);

        // A freshly issued token validates.
        assert!(server.validate_token(&token.token).is_ok());

        server.revoke_access_token(&token.token);

        // After revocation the same token is rejected.
        assert!(server.validate_token(&token.token).is_err());
    }
}
895}