Skip to main content

actix_security_core/http/security/
oauth2.rs

1//! OAuth2 and OpenID Connect (OIDC) Authentication
2//!
//! This module provides OAuth 2.0 and OpenID Connect authentication support,
4//! similar to Spring Security's OAuth2 Login.
5//!
6//! # Features
7//!
8//! - **Authorization Code Flow** - Standard OAuth2 flow for web applications
9//! - **PKCE Support** - Proof Key for Code Exchange for enhanced security
10//! - **OIDC Discovery** - Automatic provider configuration via well-known endpoints
11//! - **Multiple Providers** - Built-in support for Google, GitHub, Microsoft, etc.
12//! - **Custom Providers** - Easy to add custom OAuth2/OIDC providers
13//!
14//! # Quick Start
15//!
16//! ```rust,ignore
17//! use actix_security_core::http::security::oauth2::{
18//!     OAuth2Config, OAuth2Provider, OAuth2Client
19//! };
20//!
21//! // Configure Google OAuth2
22//! let config = OAuth2Config::new(
23//!     "your-client-id",
24//!     "your-client-secret",
25//!     "http://localhost:8080/oauth2/callback/google"
26//! )
27//! .provider(OAuth2Provider::Google)
28//! .scopes(vec!["openid", "email", "profile"]);
29//!
30//! let client = OAuth2Client::new(config).await?;
31//!
32//! // Generate authorization URL
//! let (auth_url, csrf_token, pkce_verifier, nonce) = client.authorization_url();
34//! ```
35//!
36//! # Spring Security Comparison
37//!
38//! | Spring Security | Actix Security |
39//! |-----------------|----------------|
40//! | `ClientRegistration` | `OAuth2Config` |
41//! | `ClientRegistrationRepository` | `OAuth2ClientRepository` |
42//! | `OAuth2AuthorizedClient` | `OAuth2Client` |
43//! | `OAuth2User` | `OAuth2User` |
44//! | `OidcUser` | `OidcUser` |
45
46use std::collections::HashMap;
47use std::fmt;
48use std::sync::Arc;
49
50use actix_web::dev::ServiceRequest;
51use actix_web::http::header::AUTHORIZATION;
52use oauth2::basic::BasicClient;
53use oauth2::reqwest::async_http_client;
54use oauth2::{
55    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
56    PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl,
57};
58use openidconnect::core::{
59    CoreAuthenticationFlow, CoreClient, CoreProviderMetadata,
60};
61use openidconnect::{
62    ClientId as OidcClientId, ClientSecret as OidcClientSecret, IssuerUrl, Nonce,
63    RedirectUrl as OidcRedirectUrl, TokenResponse as OidcTokenResponse,
64};
65use serde::{Deserialize, Serialize};
66use url::Url;
67
68use super::config::Authenticator;
69use super::user::User;
70
/// Errors produced while configuring or running an OAuth2 / OIDC login flow.
///
/// Every variant carries a human-readable detail message; the `Display`
/// implementation prefixes it with a short label naming the failing stage.
#[derive(Debug, Clone)]
pub enum OAuth2Error {
    /// Invalid or incomplete client configuration (e.g. a missing or unparsable URL).
    Configuration(String),
    /// OIDC provider metadata discovery failed.
    Discovery(String),
    /// Exchanging the authorization code for tokens failed.
    TokenExchange(String),
    /// The ID token was missing or its claims failed verification.
    TokenValidation(String),
    /// Fetching or interpreting the user info response failed.
    UserInfo(String),
    /// Invalid state/CSRF token.
    InvalidState(String),
    /// Invalid or missing OIDC nonce.
    InvalidNonce(String),
}

impl fmt::Display for OAuth2Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each variant to its label, then render "<label>: <detail>" once.
        let (label, detail) = match self {
            Self::Configuration(detail) => ("Configuration error", detail),
            Self::Discovery(detail) => ("Discovery error", detail),
            Self::TokenExchange(detail) => ("Token exchange error", detail),
            Self::TokenValidation(detail) => ("Token validation error", detail),
            Self::UserInfo(detail) => ("User info error", detail),
            Self::InvalidState(detail) => ("Invalid state", detail),
            Self::InvalidNonce(detail) => ("Invalid nonce", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for OAuth2Error {}
105
/// Common OAuth2/OIDC providers with pre-configured endpoints
///
/// The inherent methods on this type return built-in endpoint URLs and default
/// scopes for some variants. Variants without built-in endpoints must be configured
/// explicitly (endpoint URIs and/or issuer) on [`OAuth2Config`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OAuth2Provider {
    /// Google OAuth2/OIDC. Built-in discovery, authorization, token and userinfo endpoints.
    Google,
    /// GitHub OAuth2 (not treated as OIDC). Built-in authorization, token and userinfo endpoints.
    GitHub,
    /// Microsoft/Azure AD OAuth2/OIDC, using the multi-tenant `common` endpoints.
    Microsoft,
    /// Facebook OAuth2 (not treated as OIDC). Built-in authorization, token and userinfo endpoints.
    Facebook,
    /// Apple Sign In (OIDC). Built-in discovery, authorization and token endpoints; no userinfo endpoint.
    Apple,
    /// Okta OIDC. No built-in endpoints: configure the issuer or endpoint URIs explicitly.
    Okta,
    /// Auth0 OIDC. No built-in endpoints: configure the issuer or endpoint URIs explicitly.
    Auth0,
    /// Keycloak OIDC. No built-in endpoints: configure the issuer or endpoint URIs explicitly.
    Keycloak,
    /// Custom provider (requires manual configuration); has no built-in endpoints and
    /// `supports_oidc()` returns `false` for it.
    Custom,
}
128
129impl OAuth2Provider {
130    /// Get the OIDC discovery URL for this provider
131    pub fn discovery_url(&self) -> Option<&'static str> {
132        match self {
133            OAuth2Provider::Google => {
134                Some("https://accounts.google.com/.well-known/openid-configuration")
135            }
136            OAuth2Provider::Microsoft => {
137                Some("https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration")
138            }
139            OAuth2Provider::Apple => {
140                Some("https://appleid.apple.com/.well-known/openid-configuration")
141            }
142            _ => None,
143        }
144    }
145
146    /// Get the authorization endpoint for this provider
147    pub fn auth_url(&self) -> Option<&'static str> {
148        match self {
149            OAuth2Provider::Google => Some("https://accounts.google.com/o/oauth2/v2/auth"),
150            OAuth2Provider::GitHub => Some("https://github.com/login/oauth/authorize"),
151            OAuth2Provider::Microsoft => {
152                Some("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
153            }
154            OAuth2Provider::Facebook => Some("https://www.facebook.com/v18.0/dialog/oauth"),
155            OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/authorize"),
156            _ => None,
157        }
158    }
159
160    /// Get the token endpoint for this provider
161    pub fn token_url(&self) -> Option<&'static str> {
162        match self {
163            OAuth2Provider::Google => Some("https://oauth2.googleapis.com/token"),
164            OAuth2Provider::GitHub => Some("https://github.com/login/oauth/access_token"),
165            OAuth2Provider::Microsoft => {
166                Some("https://login.microsoftonline.com/common/oauth2/v2.0/token")
167            }
168            OAuth2Provider::Facebook => {
169                Some("https://graph.facebook.com/v18.0/oauth/access_token")
170            }
171            OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/token"),
172            _ => None,
173        }
174    }
175
176    /// Get the user info endpoint for this provider
177    pub fn userinfo_url(&self) -> Option<&'static str> {
178        match self {
179            OAuth2Provider::Google => Some("https://openidconnect.googleapis.com/v1/userinfo"),
180            OAuth2Provider::GitHub => Some("https://api.github.com/user"),
181            OAuth2Provider::Microsoft => Some("https://graph.microsoft.com/oidc/userinfo"),
182            OAuth2Provider::Facebook => Some("https://graph.facebook.com/me?fields=id,name,email"),
183            _ => None,
184        }
185    }
186
187    /// Get default scopes for this provider
188    pub fn default_scopes(&self) -> Vec<&'static str> {
189        match self {
190            OAuth2Provider::Google => vec!["openid", "email", "profile"],
191            OAuth2Provider::GitHub => vec!["read:user", "user:email"],
192            OAuth2Provider::Microsoft => vec!["openid", "email", "profile"],
193            OAuth2Provider::Facebook => vec!["email", "public_profile"],
194            OAuth2Provider::Apple => vec!["openid", "email", "name"],
195            OAuth2Provider::Okta => vec!["openid", "email", "profile"],
196            OAuth2Provider::Auth0 => vec!["openid", "email", "profile"],
197            OAuth2Provider::Keycloak => vec!["openid", "email", "profile"],
198            OAuth2Provider::Custom => vec!["openid"],
199        }
200    }
201
202    /// Check if this provider supports OIDC
203    pub fn supports_oidc(&self) -> bool {
204        matches!(
205            self,
206            OAuth2Provider::Google
207                | OAuth2Provider::Microsoft
208                | OAuth2Provider::Apple
209                | OAuth2Provider::Okta
210                | OAuth2Provider::Auth0
211                | OAuth2Provider::Keycloak
212        )
213    }
214}
215
216/// OAuth2 configuration for a client registration
217///
218/// Similar to Spring Security's `ClientRegistration`.
219///
220/// # Example
221///
222/// ```rust,ignore
223/// let config = OAuth2Config::new(
224///     "client-id",
225///     "client-secret",
226///     "http://localhost:8080/oauth2/callback/google"
227/// )
228/// .provider(OAuth2Provider::Google)
229/// .scopes(vec!["openid", "email", "profile"]);
230/// ```
231#[derive(Debug, Clone)]
232pub struct OAuth2Config {
233    /// Registration ID (e.g., "google", "github")
234    pub registration_id: String,
235    /// OAuth2 client ID
236    pub client_id: String,
237    /// OAuth2 client secret
238    pub client_secret: String,
239    /// Redirect URI for callbacks
240    pub redirect_uri: String,
241    /// OAuth2 provider
242    pub provider: OAuth2Provider,
243    /// Authorization endpoint URL (optional, auto-discovered for OIDC)
244    pub authorization_uri: Option<String>,
245    /// Token endpoint URL (optional, auto-discovered for OIDC)
246    pub token_uri: Option<String>,
247    /// User info endpoint URL (optional, auto-discovered for OIDC)
248    pub userinfo_uri: Option<String>,
249    /// OIDC issuer URL (for discovery)
250    pub issuer_uri: Option<String>,
251    /// JWK Set URI (for ID token validation)
252    pub jwk_set_uri: Option<String>,
253    /// OAuth2 scopes
254    pub scopes: Vec<String>,
255    /// Use PKCE (Proof Key for Code Exchange)
256    pub use_pkce: bool,
257    /// Custom parameters for authorization request
258    pub authorization_params: HashMap<String, String>,
259    /// Attribute name for username extraction
260    pub username_attribute: String,
261}
262
263impl OAuth2Config {
264    /// Create a new OAuth2 configuration
265    ///
266    /// # Arguments
267    ///
268    /// * `client_id` - The OAuth2 client ID
269    /// * `client_secret` - The OAuth2 client secret
270    /// * `redirect_uri` - The callback URL for authorization response
271    pub fn new(
272        client_id: impl Into<String>,
273        client_secret: impl Into<String>,
274        redirect_uri: impl Into<String>,
275    ) -> Self {
276        Self {
277            registration_id: String::new(),
278            client_id: client_id.into(),
279            client_secret: client_secret.into(),
280            redirect_uri: redirect_uri.into(),
281            provider: OAuth2Provider::Custom,
282            authorization_uri: None,
283            token_uri: None,
284            userinfo_uri: None,
285            issuer_uri: None,
286            jwk_set_uri: None,
287            scopes: vec!["openid".to_string()],
288            use_pkce: true,
289            authorization_params: HashMap::new(),
290            username_attribute: "sub".to_string(),
291        }
292    }
293
294    /// Set the registration ID
295    pub fn registration_id(mut self, id: impl Into<String>) -> Self {
296        self.registration_id = id.into();
297        self
298    }
299
300    /// Set the OAuth2 provider
301    ///
302    /// This will auto-configure endpoints for known providers.
303    pub fn provider(mut self, provider: OAuth2Provider) -> Self {
304        self.provider = provider;
305        if self.registration_id.is_empty() {
306            self.registration_id = format!("{:?}", provider).to_lowercase();
307        }
308
309        // Set default endpoints from provider
310        if let Some(auth_url) = provider.auth_url() {
311            self.authorization_uri = Some(auth_url.to_string());
312        }
313        if let Some(token_url) = provider.token_url() {
314            self.token_uri = Some(token_url.to_string());
315        }
316        if let Some(userinfo_url) = provider.userinfo_url() {
317            self.userinfo_uri = Some(userinfo_url.to_string());
318        }
319
320        // Set default scopes
321        if self.scopes.len() == 1 && self.scopes[0] == "openid" {
322            self.scopes = provider
323                .default_scopes()
324                .into_iter()
325                .map(String::from)
326                .collect();
327        }
328
329        // GitHub doesn't support PKCE
330        if matches!(provider, OAuth2Provider::GitHub | OAuth2Provider::Facebook) {
331            self.use_pkce = false;
332        }
333
334        self
335    }
336
337    /// Set the authorization endpoint URL
338    pub fn authorization_uri(mut self, uri: impl Into<String>) -> Self {
339        self.authorization_uri = Some(uri.into());
340        self
341    }
342
343    /// Set the token endpoint URL
344    pub fn token_uri(mut self, uri: impl Into<String>) -> Self {
345        self.token_uri = Some(uri.into());
346        self
347    }
348
349    /// Set the user info endpoint URL
350    pub fn userinfo_uri(mut self, uri: impl Into<String>) -> Self {
351        self.userinfo_uri = Some(uri.into());
352        self
353    }
354
355    /// Set the OIDC issuer URL for auto-discovery
356    pub fn issuer_uri(mut self, uri: impl Into<String>) -> Self {
357        self.issuer_uri = Some(uri.into());
358        self
359    }
360
361    /// Set the JWK Set URI for ID token validation
362    pub fn jwk_set_uri(mut self, uri: impl Into<String>) -> Self {
363        self.jwk_set_uri = Some(uri.into());
364        self
365    }
366
367    /// Set the OAuth2 scopes
368    pub fn scopes(mut self, scopes: Vec<impl Into<String>>) -> Self {
369        self.scopes = scopes.into_iter().map(|s| s.into()).collect();
370        self
371    }
372
373    /// Add a scope
374    pub fn add_scope(mut self, scope: impl Into<String>) -> Self {
375        self.scopes.push(scope.into());
376        self
377    }
378
379    /// Enable or disable PKCE
380    pub fn use_pkce(mut self, use_pkce: bool) -> Self {
381        self.use_pkce = use_pkce;
382        self
383    }
384
385    /// Add a custom authorization parameter
386    pub fn authorization_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
387        self.authorization_params.insert(key.into(), value.into());
388        self
389    }
390
391    /// Set the attribute name used for extracting the username
392    pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
393        self.username_attribute = attr.into();
394        self
395    }
396}
397
398/// User information retrieved from OAuth2 provider
399#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct OAuth2User {
401    /// User's unique identifier (subject)
402    pub sub: String,
403    /// User's name
404    pub name: Option<String>,
405    /// User's given name
406    pub given_name: Option<String>,
407    /// User's family name
408    pub family_name: Option<String>,
409    /// User's email
410    pub email: Option<String>,
411    /// Whether email is verified
412    pub email_verified: Option<bool>,
413    /// User's picture URL
414    pub picture: Option<String>,
415    /// User's locale
416    pub locale: Option<String>,
417    /// Provider-specific attributes
418    #[serde(flatten)]
419    pub attributes: HashMap<String, serde_json::Value>,
420    /// OAuth2 access token
421    #[serde(skip)]
422    pub access_token: Option<String>,
423    /// OAuth2 refresh token
424    #[serde(skip)]
425    pub refresh_token: Option<String>,
426    /// Token expiration time (Unix timestamp)
427    pub expires_at: Option<i64>,
428    /// Provider that authenticated this user
429    pub provider: String,
430}
431
432impl OAuth2User {
433    /// Create a new OAuth2User with minimal information
434    pub fn new(sub: impl Into<String>, provider: impl Into<String>) -> Self {
435        Self {
436            sub: sub.into(),
437            name: None,
438            given_name: None,
439            family_name: None,
440            email: None,
441            email_verified: None,
442            picture: None,
443            locale: None,
444            attributes: HashMap::new(),
445            access_token: None,
446            refresh_token: None,
447            expires_at: None,
448            provider: provider.into(),
449        }
450    }
451
452    /// Get a specific attribute value
453    pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
454        self.attributes.get(key)
455    }
456
457    /// Get the username (tries email first, then sub)
458    pub fn username(&self) -> &str {
459        self.email.as_deref().unwrap_or(&self.sub)
460    }
461
462    /// Convert to a User for authentication
463    pub fn to_user(&self) -> User {
464        User::new(self.username().to_string(), String::new())
465            .roles(&["USER".to_string()])
466            .authorities(&[format!("OAUTH2_USER_{}", self.provider.to_uppercase())])
467    }
468}
469
470/// OIDC user with ID token claims
471#[derive(Debug, Clone, Serialize, Deserialize)]
472pub struct OidcUser {
473    /// Base OAuth2 user info
474    #[serde(flatten)]
475    pub oauth2_user: OAuth2User,
476    /// ID token claims
477    pub id_token_claims: Option<IdTokenClaims>,
478    /// Raw ID token (JWT)
479    #[serde(skip)]
480    pub id_token: Option<String>,
481}
482
/// ID Token claims
///
/// A subset of the OIDC ID token claims, kept on an [`OidcUser`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdTokenClaims {
    /// Issuer identifier (`iss`)
    pub iss: String,
    /// Subject identifier (`sub`)
    pub sub: String,
    /// Audience(s) (`aud`)
    pub aud: Vec<String>,
    /// Expiration time (`exp`), as a Unix timestamp (seconds)
    pub exp: i64,
    /// Issued-at time (`iat`), as a Unix timestamp (seconds)
    pub iat: i64,
    /// Time of end-user authentication (`auth_time`), if known
    pub auth_time: Option<i64>,
    /// Nonce (`nonce`), if known
    pub nonce: Option<String>,
    /// Access token hash (`at_hash`), if known
    pub at_hash: Option<String>,
}
503
504impl OidcUser {
505    /// Convert to a User for authentication
506    pub fn to_user(&self) -> User {
507        self.oauth2_user.to_user()
508    }
509}
510
/// Authorization request state (stored in session)
///
/// Captures the values generated when an authorization request is started that are
/// needed again when the provider calls back.
///
/// NOTE(review): the derived `Debug` prints `pkce_verifier` and `nonce` verbatim; avoid
/// logging this value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationRequestState {
    /// CSRF state token sent with the authorization request
    pub state: String,
    /// PKCE code verifier (if PKCE is used); needed for the code exchange
    pub pkce_verifier: Option<String>,
    /// OIDC nonce (if OIDC is used); needed to verify the ID token
    pub nonce: Option<String>,
    /// Redirect URL after successful authentication
    pub redirect_uri: Option<String>,
    /// Provider registration ID
    pub registration_id: String,
    /// Timestamp when the request was created
    /// NOTE(review): unit/epoch not visible here — presumably Unix seconds; confirm
    /// against the code that constructs this value.
    pub created_at: i64,
}
527
/// OAuth2 client for handling authorization flows
///
/// Similar to Spring Security's `OAuth2AuthorizedClientService`.
///
/// Holds one [`OAuth2Config`] together with a plain OAuth2 client and, when OIDC
/// discovery succeeded in [`OAuth2Client::new`], an OIDC client. The OIDC client is
/// shared behind an `Arc`, so cloning does not duplicate it.
#[derive(Clone)]
pub struct OAuth2Client {
    // Registration this client was built from.
    config: OAuth2Config,
    // Plain OAuth2 client; always present, used when `oidc_client` is `None`.
    oauth2_client: BasicClient,
    // OIDC client; `Some` only when discovery succeeded.
    oidc_client: Option<Arc<CoreClient>>,
}
537
538impl OAuth2Client {
539    /// Create a new OAuth2 client from configuration
540    ///
541    /// For OIDC providers, this will perform discovery to fetch provider metadata.
542    pub async fn new(config: OAuth2Config) -> Result<Self, OAuth2Error> {
543        // Build the basic OAuth2 client
544        let auth_url = config
545            .authorization_uri
546            .as_ref()
547            .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
548
549        let token_url = config
550            .token_uri
551            .as_ref()
552            .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
553
554        let oauth2_client = BasicClient::new(
555            ClientId::new(config.client_id.clone()),
556            Some(ClientSecret::new(config.client_secret.clone())),
557            AuthUrl::new(auth_url.clone())
558                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
559            Some(
560                TokenUrl::new(token_url.clone())
561                    .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
562            ),
563        )
564        .set_redirect_uri(
565            RedirectUrl::new(config.redirect_uri.clone())
566                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
567        );
568
569        // For OIDC providers, try to create an OIDC client
570        let oidc_client = if config.provider.supports_oidc() {
571            if let Some(issuer_uri) = &config.issuer_uri {
572                match Self::create_oidc_client(&config, issuer_uri).await {
573                    Ok(client) => Some(Arc::new(client)),
574                    Err(e) => {
575                        // Log warning but continue without OIDC
576                        eprintln!("Warning: OIDC discovery failed: {}", e);
577                        None
578                    }
579                }
580            } else if let Some(discovery_url) = config.provider.discovery_url() {
581                // Extract issuer from discovery URL
582                let issuer = discovery_url.trim_end_matches("/.well-known/openid-configuration");
583                match Self::create_oidc_client(&config, issuer).await {
584                    Ok(client) => Some(Arc::new(client)),
585                    Err(e) => {
586                        eprintln!("Warning: OIDC discovery failed: {}", e);
587                        None
588                    }
589                }
590            } else {
591                None
592            }
593        } else {
594            None
595        };
596
597        Ok(Self {
598            config,
599            oauth2_client,
600            oidc_client,
601        })
602    }
603
604    /// Create a new OAuth2 client without OIDC discovery (sync)
605    ///
606    /// Use this when you don't need OIDC features or want to avoid async initialization.
607    pub fn new_basic(config: OAuth2Config) -> Result<Self, OAuth2Error> {
608        let auth_url = config
609            .authorization_uri
610            .as_ref()
611            .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
612
613        let token_url = config
614            .token_uri
615            .as_ref()
616            .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
617
618        let oauth2_client = BasicClient::new(
619            ClientId::new(config.client_id.clone()),
620            Some(ClientSecret::new(config.client_secret.clone())),
621            AuthUrl::new(auth_url.clone())
622                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
623            Some(
624                TokenUrl::new(token_url.clone())
625                    .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
626            ),
627        )
628        .set_redirect_uri(
629            RedirectUrl::new(config.redirect_uri.clone())
630                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
631        );
632
633        Ok(Self {
634            config,
635            oauth2_client,
636            oidc_client: None,
637        })
638    }
639
640    /// Create an OIDC client with discovery
641    async fn create_oidc_client(
642        config: &OAuth2Config,
643        issuer_uri: &str,
644    ) -> Result<CoreClient, OAuth2Error> {
645        let issuer_url = IssuerUrl::new(issuer_uri.to_string())
646            .map_err(|e| OAuth2Error::Configuration(e.to_string()))?;
647
648        // Discover provider metadata using openidconnect's async http client
649        let provider_metadata = CoreProviderMetadata::discover_async(
650            issuer_url,
651            openidconnect::reqwest::async_http_client,
652        )
653        .await
654        .map_err(|e| OAuth2Error::Discovery(format!("{:?}", e)))?;
655
656        // Build OIDC client
657        let client = CoreClient::from_provider_metadata(
658            provider_metadata,
659            OidcClientId::new(config.client_id.clone()),
660            Some(OidcClientSecret::new(config.client_secret.clone())),
661        )
662        .set_redirect_uri(
663            OidcRedirectUrl::new(config.redirect_uri.clone())
664                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
665        );
666
667        Ok(client)
668    }
669
670    /// Generate an authorization URL for the OAuth2 flow
671    ///
672    /// Returns (authorization_url, state, pkce_verifier, nonce)
673    pub fn authorization_url(
674        &self,
675    ) -> (Url, CsrfToken, Option<PkceCodeVerifier>, Option<Nonce>) {
676        if let Some(oidc_client) = &self.oidc_client {
677            // OIDC flow with nonce
678            let mut auth_request = oidc_client.authorize_url(
679                CoreAuthenticationFlow::AuthorizationCode,
680                CsrfToken::new_random,
681                Nonce::new_random,
682            );
683
684            // Add scopes
685            for scope in &self.config.scopes {
686                auth_request = auth_request.add_scope(openidconnect::Scope::new(scope.clone()));
687            }
688
689            // Add PKCE if enabled
690            let pkce_verifier = if self.config.use_pkce {
691                let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
692                auth_request = auth_request.set_pkce_challenge(pkce_challenge);
693                Some(pkce_verifier)
694            } else {
695                None
696            };
697
698            let (url, state, nonce) = auth_request.url();
699            (url, state, pkce_verifier, Some(nonce))
700        } else {
701            // Standard OAuth2 flow
702            let mut auth_request = self.oauth2_client.authorize_url(CsrfToken::new_random);
703
704            // Add scopes
705            for scope in &self.config.scopes {
706                auth_request = auth_request.add_scope(Scope::new(scope.clone()));
707            }
708
709            // Add PKCE if enabled
710            let pkce_verifier = if self.config.use_pkce {
711                let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
712                auth_request = auth_request.set_pkce_challenge(pkce_challenge);
713                Some(pkce_verifier)
714            } else {
715                None
716            };
717
718            let (url, state) = auth_request.url();
719            (url, state, pkce_verifier, None)
720        }
721    }
722
723    /// Exchange authorization code for tokens
724    pub async fn exchange_code(
725        &self,
726        code: &str,
727        pkce_verifier: Option<PkceCodeVerifier>,
728        nonce: Option<&Nonce>,
729    ) -> Result<(OAuth2User, Option<OidcUser>), OAuth2Error> {
730        if let Some(oidc_client) = &self.oidc_client {
731            // OIDC token exchange using openidconnect's async client
732            let mut token_request =
733                oidc_client.exchange_code(AuthorizationCode::new(code.to_string()));
734
735            if let Some(verifier) = pkce_verifier {
736                token_request = token_request.set_pkce_verifier(verifier);
737            }
738
739            let token_response = token_request
740                .request_async(openidconnect::reqwest::async_http_client)
741                .await
742                .map_err(|e| OAuth2Error::TokenExchange(format!("{:?}", e)))?;
743
744            // Verify and extract ID token claims
745            let id_token = token_response
746                .id_token()
747                .ok_or_else(|| OAuth2Error::TokenValidation("Missing ID token".to_string()))?;
748
749            let id_token_verifier = oidc_client.id_token_verifier();
750            let nonce_ref = nonce.cloned().unwrap_or_else(|| Nonce::new(String::new()));
751            let claims = id_token
752                .claims(&id_token_verifier, &nonce_ref)
753                .map_err(|e: openidconnect::ClaimsVerificationError| {
754                    OAuth2Error::TokenValidation(e.to_string())
755                })?;
756
757            // Extract basic info from claims
758            let subject = claims.subject().as_str().to_string();
759            let issuer = claims.issuer().as_str().to_string();
760            let exp = claims.expiration().timestamp();
761            let iat = claims.issue_time().timestamp();
762
763            // Build OAuth2User
764            let mut oauth2_user = OAuth2User::new(&subject, &self.config.registration_id);
765            oauth2_user.access_token = Some(token_response.access_token().secret().clone());
766            oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
767            oauth2_user.email_verified = claims.email_verified();
768
769            // Build OidcUser with basic claims
770            let oidc_user = OidcUser {
771                oauth2_user: oauth2_user.clone(),
772                id_token_claims: Some(IdTokenClaims {
773                    iss: issuer,
774                    sub: subject,
775                    aud: vec![self.config.client_id.clone()],
776                    exp,
777                    iat,
778                    auth_time: None,
779                    nonce: None,
780                    at_hash: None,
781                }),
782                id_token: Some(id_token.to_string()),
783            };
784
785            Ok((oauth2_user, Some(oidc_user)))
786        } else {
787            // Standard OAuth2 token exchange
788            let mut token_request = self
789                .oauth2_client
790                .exchange_code(AuthorizationCode::new(code.to_string()));
791
792            if let Some(verifier) = pkce_verifier {
793                token_request = token_request.set_pkce_verifier(verifier);
794            }
795
796            let token_response = token_request
797                .request_async(async_http_client)
798                .await
799                .map_err(|e| OAuth2Error::TokenExchange(e.to_string()))?;
800
801            // Fetch user info
802            let mut oauth2_user = self
803                .fetch_user_info(token_response.access_token().secret())
804                .await?;
805
806            oauth2_user.access_token = Some(token_response.access_token().secret().clone());
807            oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
808
809            Ok((oauth2_user, None))
810        }
811    }
812
813    /// Fetch user info from the provider's userinfo endpoint
814    async fn fetch_user_info(&self, access_token: &str) -> Result<OAuth2User, OAuth2Error> {
815        let userinfo_url = self
816            .config
817            .userinfo_uri
818            .as_ref()
819            .ok_or_else(|| OAuth2Error::UserInfo("Missing userinfo URI".to_string()))?;
820
821        let http_client = reqwest::Client::new();
822        let response = http_client
823            .get(userinfo_url)
824            .bearer_auth(access_token)
825            .header("Accept", "application/json")
826            .send()
827            .await
828            .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
829
830        if !response.status().is_success() {
831            return Err(OAuth2Error::UserInfo(format!(
832                "HTTP {}: {}",
833                response.status(),
834                response.text().await.unwrap_or_default()
835            )));
836        }
837
838        let attributes: HashMap<String, serde_json::Value> = response
839            .json()
840            .await
841            .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
842
843        // Extract user ID based on provider
844        let sub = self.extract_user_id(&attributes)?;
845        let mut user = OAuth2User::new(sub, &self.config.registration_id);
846        user.access_token = Some(access_token.to_string());
847        user.attributes = attributes.clone();
848
849        // Extract common fields
850        user.name = attributes
851            .get("name")
852            .and_then(|v| v.as_str())
853            .map(String::from);
854        user.email = attributes
855            .get("email")
856            .and_then(|v| v.as_str())
857            .map(String::from);
858        user.picture = attributes
859            .get("picture")
860            .or_else(|| attributes.get("avatar_url"))
861            .and_then(|v| v.as_str())
862            .map(String::from);
863
864        Ok(user)
865    }
866
867    /// Extract user ID from attributes based on provider
868    fn extract_user_id(
869        &self,
870        attributes: &HashMap<String, serde_json::Value>,
871    ) -> Result<String, OAuth2Error> {
872        // Try common ID fields
873        let id_fields = ["sub", "id", "user_id", "login"];
874
875        for field in &id_fields {
876            if let Some(value) = attributes.get(*field) {
877                if let Some(s) = value.as_str() {
878                    return Ok(s.to_string());
879                }
880                if let Some(n) = value.as_i64() {
881                    return Ok(n.to_string());
882                }
883            }
884        }
885
886        Err(OAuth2Error::UserInfo(
887            "Could not extract user ID".to_string(),
888        ))
889    }
890
891    /// Get the configuration
892    pub fn config(&self) -> &OAuth2Config {
893        &self.config
894    }
895
896    /// Check if OIDC is available
897    pub fn has_oidc(&self) -> bool {
898        self.oidc_client.is_some()
899    }
900}
901
902/// Repository for multiple OAuth2 client registrations
903///
904/// Similar to Spring Security's `ClientRegistrationRepository`.
905#[derive(Clone, Default)]
906pub struct OAuth2ClientRepository {
907    clients: HashMap<String, OAuth2Client>,
908}
909
910impl OAuth2ClientRepository {
911    /// Create a new empty repository
912    pub fn new() -> Self {
913        Self {
914            clients: HashMap::new(),
915        }
916    }
917
918    /// Add a client registration
919    pub fn add_client(&mut self, client: OAuth2Client) {
920        self.clients
921            .insert(client.config.registration_id.clone(), client);
922    }
923
924    /// Get a client by registration ID
925    pub fn get_client(&self, registration_id: &str) -> Option<&OAuth2Client> {
926        self.clients.get(registration_id)
927    }
928
929    /// Get all registration IDs
930    pub fn registration_ids(&self) -> Vec<&String> {
931        self.clients.keys().collect()
932    }
933
934    /// Build a repository from multiple configurations
935    pub async fn from_configs(configs: Vec<OAuth2Config>) -> Result<Self, OAuth2Error> {
936        let mut repo = Self::new();
937        for config in configs {
938            let client = OAuth2Client::new(config).await?;
939            repo.add_client(client);
940        }
941        Ok(repo)
942    }
943}
944
/// Authenticator for requests that carry an OAuth2 Bearer access token.
///
/// The token is read from the `Authorization: Bearer <token>` header. Token validation against
/// the provider is not implemented yet: the `Authenticator::get_user` implementation for this
/// type currently returns `None` for every request.
#[derive(Clone)]
pub struct OAuth2Authenticator {
    /// Expected token issuer; unset by default.
    issuer: Option<String>,
    /// URI of the JWKS used for token validation; unset by default.
    jwks_uri: Option<String>,
    /// Attribute to take the username from; `"sub"` by default.
    username_attribute: String,
}
958
959impl OAuth2Authenticator {
960    /// Create a new OAuth2 authenticator
961    pub fn new() -> Self {
962        Self {
963            issuer: None,
964            jwks_uri: None,
965            username_attribute: "sub".to_string(),
966        }
967    }
968
969    /// Set the expected issuer
970    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
971        self.issuer = Some(issuer.into());
972        self
973    }
974
975    /// Set the JWKS URI for token validation
976    pub fn jwks_uri(mut self, uri: impl Into<String>) -> Self {
977        self.jwks_uri = Some(uri.into());
978        self
979    }
980
981    /// Set the attribute to use as username
982    pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
983        self.username_attribute = attr.into();
984        self
985    }
986
987    /// Extract Bearer token from request
988    fn extract_token(&self, req: &ServiceRequest) -> Option<String> {
989        let auth_header = req.headers().get(AUTHORIZATION)?;
990        let auth_str = auth_header.to_str().ok()?;
991
992        auth_str
993            .strip_prefix("Bearer ")
994            .map(|token| token.to_string())
995    }
996}
997
998impl Default for OAuth2Authenticator {
999    fn default() -> Self {
1000        Self::new()
1001    }
1002}
1003
1004impl Authenticator for OAuth2Authenticator {
1005    fn get_user(&self, req: &ServiceRequest) -> Option<User> {
1006        // This is a simplified implementation
1007        // In production, you would validate the token against JWKS
1008        let _token = self.extract_token(req)?;
1009
1010        // For now, we return None as token validation requires async
1011        // The actual validation should be done via OAuth2CallbackHandler
1012        None
1013    }
1014}
1015
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records the client ID and provider, exposes the expected scopes (including an
    /// added custom one), and leaves PKCE enabled.
    #[test]
    fn test_oauth2_config_builder() {
        let cfg = OAuth2Config::new("client-id", "secret", "http://localhost/callback")
            .provider(OAuth2Provider::Google)
            .add_scope("custom_scope");

        assert_eq!(cfg.client_id, "client-id");
        assert_eq!(cfg.provider, OAuth2Provider::Google);
        for expected in ["openid", "email", "custom_scope"] {
            assert!(
                cfg.scopes.contains(&expected.to_string()),
                "missing scope: {}",
                expected
            );
        }
        assert!(cfg.use_pkce);
    }

    /// Google exposes all endpoints and OIDC; GitHub has an auth URL but no OIDC.
    #[test]
    fn test_oauth2_provider_endpoints() {
        assert!(OAuth2Provider::Google.auth_url().is_some());
        assert!(OAuth2Provider::Google.token_url().is_some());
        assert!(OAuth2Provider::Google.userinfo_url().is_some());
        assert!(OAuth2Provider::Google.supports_oidc());

        assert!(OAuth2Provider::GitHub.auth_url().is_some());
        assert!(!OAuth2Provider::GitHub.supports_oidc());
    }

    /// The email is preferred as the username, and the converted user carries the USER role.
    #[test]
    fn test_oauth2_user() {
        let mut oauth_user = OAuth2User::new("user123", "google");
        oauth_user.email = Some("user@example.com".to_string());
        oauth_user.name = Some("Test User".to_string());

        assert_eq!(oauth_user.username(), "user@example.com");

        let converted = oauth_user.to_user();
        assert_eq!(converted.get_username(), "user@example.com");
        assert!(converted.get_roles().contains(&"USER".to_string()));
    }

    /// Each provider advertises its characteristic default scopes.
    #[test]
    fn test_provider_default_scopes() {
        let google = OAuth2Provider::Google.default_scopes();
        for scope in ["openid", "email"] {
            assert!(google.contains(&scope), "Google is missing scope: {}", scope);
        }

        let github = OAuth2Provider::GitHub.default_scopes();
        assert!(github.contains(&"read:user"));
    }

    /// A basic (non-OIDC) client can be built for GitHub without network access.
    #[test]
    fn test_oauth2_client_basic() {
        let cfg = OAuth2Config::new("client-id", "secret", "http://localhost/callback")
            .provider(OAuth2Provider::GitHub);

        let client = OAuth2Client::new_basic(cfg).expect("basic GitHub client should build");
        assert!(!client.has_oidc());
    }
}