Skip to main content

actix_security_core/http/security/
oauth2.rs

1//! OAuth2 and OpenID Connect (OIDC) Authentication
2//!
//! This module provides OAuth 2.0 and OpenID Connect authentication support,
4//! similar to Spring Security's OAuth2 Login.
5//!
6//! # Features
7//!
8//! - **Authorization Code Flow** - Standard OAuth2 flow for web applications
9//! - **PKCE Support** - Proof Key for Code Exchange for enhanced security
10//! - **OIDC Discovery** - Automatic provider configuration via well-known endpoints
11//! - **Multiple Providers** - Built-in support for Google, GitHub, Microsoft, etc.
12//! - **Custom Providers** - Easy to add custom OAuth2/OIDC providers
13//!
14//! # Quick Start
15//!
16//! ```rust,ignore
17//! use actix_security_core::http::security::oauth2::{
18//!     OAuth2Config, OAuth2Provider, OAuth2Client
19//! };
20//!
21//! // Configure Google OAuth2
22//! let config = OAuth2Config::new(
23//!     "your-client-id",
24//!     "your-client-secret",
25//!     "http://localhost:8080/oauth2/callback/google"
26//! )
27//! .provider(OAuth2Provider::Google)
28//! .scopes(vec!["openid", "email", "profile"]);
29//!
30//! let client = OAuth2Client::new(config).await?;
31//!
32//! // Generate authorization URL
//! let (auth_url, csrf_token, pkce_verifier, nonce) = client.authorization_url();
34//! ```
35//!
36//! # Spring Security Comparison
37//!
38//! | Spring Security | Actix Security |
39//! |-----------------|----------------|
40//! | `ClientRegistration` | `OAuth2Config` |
41//! | `ClientRegistrationRepository` | `OAuth2ClientRepository` |
42//! | `OAuth2AuthorizedClient` | `OAuth2Client` |
43//! | `OAuth2User` | `OAuth2User` |
44//! | `OidcUser` | `OidcUser` |
45
46use std::collections::HashMap;
47use std::fmt;
48use std::sync::Arc;
49
50use actix_web::dev::ServiceRequest;
51use actix_web::http::header::AUTHORIZATION;
52use oauth2::basic::BasicClient;
53use oauth2::reqwest::async_http_client;
54use oauth2::{
55    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
56    PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl,
57};
58use openidconnect::core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata};
59use openidconnect::{
60    ClientId as OidcClientId, ClientSecret as OidcClientSecret, IssuerUrl, Nonce,
61    RedirectUrl as OidcRedirectUrl, TokenResponse as OidcTokenResponse,
62};
63use serde::{Deserialize, Serialize};
64use url::Url;
65
66use super::config::Authenticator;
67use super::user::User;
68
/// Errors produced while configuring or running an OAuth2 / OIDC flow.
///
/// Every variant carries a human-readable detail message that is appended to the
/// variant's category label when the error is displayed.
#[derive(Clone, Debug)]
pub enum OAuth2Error {
    /// The client configuration is incomplete or contains an invalid value.
    Configuration(String),
    /// Provider (OIDC) discovery failed.
    Discovery(String),
    /// Exchanging the authorization code for tokens failed.
    TokenExchange(String),
    /// A returned token failed validation.
    TokenValidation(String),
    /// Retrieving user information from the provider failed.
    UserInfo(String),
    /// The state / CSRF token was invalid.
    InvalidState(String),
    /// The OIDC nonce was invalid.
    InvalidNonce(String),
}

impl fmt::Display for OAuth2Error {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (category, detail) = match self {
            Self::Configuration(detail) => ("Configuration error", detail),
            Self::Discovery(detail) => ("Discovery error", detail),
            Self::TokenExchange(detail) => ("Token exchange error", detail),
            Self::TokenValidation(detail) => ("Token validation error", detail),
            Self::UserInfo(detail) => ("User info error", detail),
            Self::InvalidState(detail) => ("Invalid state", detail),
            Self::InvalidNonce(detail) => ("Invalid nonce", detail),
        };
        write!(f, "{}: {}", category, detail)
    }
}

impl std::error::Error for OAuth2Error {}
103
104/// Common OAuth2/OIDC providers with pre-configured endpoints
105#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
106pub enum OAuth2Provider {
107    /// Google OAuth2/OIDC
108    Google,
109    /// GitHub OAuth2
110    GitHub,
111    /// Microsoft/Azure AD OAuth2/OIDC
112    Microsoft,
113    /// Facebook OAuth2
114    Facebook,
115    /// Apple Sign In
116    Apple,
117    /// Okta OIDC
118    Okta,
119    /// Auth0 OIDC
120    Auth0,
121    /// Keycloak OIDC
122    Keycloak,
123    /// Custom provider (requires manual configuration)
124    Custom,
125}
126
127impl OAuth2Provider {
128    /// Get the OIDC discovery URL for this provider
129    pub fn discovery_url(&self) -> Option<&'static str> {
130        match self {
131            OAuth2Provider::Google => {
132                Some("https://accounts.google.com/.well-known/openid-configuration")
133            }
134            OAuth2Provider::Microsoft => Some(
135                "https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration",
136            ),
137            OAuth2Provider::Apple => {
138                Some("https://appleid.apple.com/.well-known/openid-configuration")
139            }
140            _ => None,
141        }
142    }
143
144    /// Get the authorization endpoint for this provider
145    pub fn auth_url(&self) -> Option<&'static str> {
146        match self {
147            OAuth2Provider::Google => Some("https://accounts.google.com/o/oauth2/v2/auth"),
148            OAuth2Provider::GitHub => Some("https://github.com/login/oauth/authorize"),
149            OAuth2Provider::Microsoft => {
150                Some("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
151            }
152            OAuth2Provider::Facebook => Some("https://www.facebook.com/v18.0/dialog/oauth"),
153            OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/authorize"),
154            _ => None,
155        }
156    }
157
158    /// Get the token endpoint for this provider
159    pub fn token_url(&self) -> Option<&'static str> {
160        match self {
161            OAuth2Provider::Google => Some("https://oauth2.googleapis.com/token"),
162            OAuth2Provider::GitHub => Some("https://github.com/login/oauth/access_token"),
163            OAuth2Provider::Microsoft => {
164                Some("https://login.microsoftonline.com/common/oauth2/v2.0/token")
165            }
166            OAuth2Provider::Facebook => Some("https://graph.facebook.com/v18.0/oauth/access_token"),
167            OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/token"),
168            _ => None,
169        }
170    }
171
172    /// Get the user info endpoint for this provider
173    pub fn userinfo_url(&self) -> Option<&'static str> {
174        match self {
175            OAuth2Provider::Google => Some("https://openidconnect.googleapis.com/v1/userinfo"),
176            OAuth2Provider::GitHub => Some("https://api.github.com/user"),
177            OAuth2Provider::Microsoft => Some("https://graph.microsoft.com/oidc/userinfo"),
178            OAuth2Provider::Facebook => Some("https://graph.facebook.com/me?fields=id,name,email"),
179            _ => None,
180        }
181    }
182
183    /// Get default scopes for this provider
184    pub fn default_scopes(&self) -> Vec<&'static str> {
185        match self {
186            OAuth2Provider::Google => vec!["openid", "email", "profile"],
187            OAuth2Provider::GitHub => vec!["read:user", "user:email"],
188            OAuth2Provider::Microsoft => vec!["openid", "email", "profile"],
189            OAuth2Provider::Facebook => vec!["email", "public_profile"],
190            OAuth2Provider::Apple => vec!["openid", "email", "name"],
191            OAuth2Provider::Okta => vec!["openid", "email", "profile"],
192            OAuth2Provider::Auth0 => vec!["openid", "email", "profile"],
193            OAuth2Provider::Keycloak => vec!["openid", "email", "profile"],
194            OAuth2Provider::Custom => vec!["openid"],
195        }
196    }
197
198    /// Check if this provider supports OIDC
199    pub fn supports_oidc(&self) -> bool {
200        matches!(
201            self,
202            OAuth2Provider::Google
203                | OAuth2Provider::Microsoft
204                | OAuth2Provider::Apple
205                | OAuth2Provider::Okta
206                | OAuth2Provider::Auth0
207                | OAuth2Provider::Keycloak
208        )
209    }
210}
211
212/// OAuth2 configuration for a client registration
213///
214/// Similar to Spring Security's `ClientRegistration`.
215///
216/// # Example
217///
218/// ```rust,ignore
219/// let config = OAuth2Config::new(
220///     "client-id",
221///     "client-secret",
222///     "http://localhost:8080/oauth2/callback/google"
223/// )
224/// .provider(OAuth2Provider::Google)
225/// .scopes(vec!["openid", "email", "profile"]);
226/// ```
227#[derive(Debug, Clone)]
228pub struct OAuth2Config {
229    /// Registration ID (e.g., "google", "github")
230    pub registration_id: String,
231    /// OAuth2 client ID
232    pub client_id: String,
233    /// OAuth2 client secret
234    pub client_secret: String,
235    /// Redirect URI for callbacks
236    pub redirect_uri: String,
237    /// OAuth2 provider
238    pub provider: OAuth2Provider,
239    /// Authorization endpoint URL (optional, auto-discovered for OIDC)
240    pub authorization_uri: Option<String>,
241    /// Token endpoint URL (optional, auto-discovered for OIDC)
242    pub token_uri: Option<String>,
243    /// User info endpoint URL (optional, auto-discovered for OIDC)
244    pub userinfo_uri: Option<String>,
245    /// OIDC issuer URL (for discovery)
246    pub issuer_uri: Option<String>,
247    /// JWK Set URI (for ID token validation)
248    pub jwk_set_uri: Option<String>,
249    /// OAuth2 scopes
250    pub scopes: Vec<String>,
251    /// Use PKCE (Proof Key for Code Exchange)
252    pub use_pkce: bool,
253    /// Custom parameters for authorization request
254    pub authorization_params: HashMap<String, String>,
255    /// Attribute name for username extraction
256    pub username_attribute: String,
257}
258
259impl OAuth2Config {
260    /// Create a new OAuth2 configuration
261    ///
262    /// # Arguments
263    ///
264    /// * `client_id` - The OAuth2 client ID
265    /// * `client_secret` - The OAuth2 client secret
266    /// * `redirect_uri` - The callback URL for authorization response
267    pub fn new(
268        client_id: impl Into<String>,
269        client_secret: impl Into<String>,
270        redirect_uri: impl Into<String>,
271    ) -> Self {
272        Self {
273            registration_id: String::new(),
274            client_id: client_id.into(),
275            client_secret: client_secret.into(),
276            redirect_uri: redirect_uri.into(),
277            provider: OAuth2Provider::Custom,
278            authorization_uri: None,
279            token_uri: None,
280            userinfo_uri: None,
281            issuer_uri: None,
282            jwk_set_uri: None,
283            scopes: vec!["openid".to_string()],
284            use_pkce: true,
285            authorization_params: HashMap::new(),
286            username_attribute: "sub".to_string(),
287        }
288    }
289
290    /// Set the registration ID
291    pub fn registration_id(mut self, id: impl Into<String>) -> Self {
292        self.registration_id = id.into();
293        self
294    }
295
296    /// Set the OAuth2 provider
297    ///
298    /// This will auto-configure endpoints for known providers.
299    pub fn provider(mut self, provider: OAuth2Provider) -> Self {
300        self.provider = provider;
301        if self.registration_id.is_empty() {
302            self.registration_id = format!("{:?}", provider).to_lowercase();
303        }
304
305        // Set default endpoints from provider
306        if let Some(auth_url) = provider.auth_url() {
307            self.authorization_uri = Some(auth_url.to_string());
308        }
309        if let Some(token_url) = provider.token_url() {
310            self.token_uri = Some(token_url.to_string());
311        }
312        if let Some(userinfo_url) = provider.userinfo_url() {
313            self.userinfo_uri = Some(userinfo_url.to_string());
314        }
315
316        // Set default scopes
317        if self.scopes.len() == 1 && self.scopes[0] == "openid" {
318            self.scopes = provider
319                .default_scopes()
320                .into_iter()
321                .map(String::from)
322                .collect();
323        }
324
325        // GitHub doesn't support PKCE
326        if matches!(provider, OAuth2Provider::GitHub | OAuth2Provider::Facebook) {
327            self.use_pkce = false;
328        }
329
330        self
331    }
332
333    /// Set the authorization endpoint URL
334    pub fn authorization_uri(mut self, uri: impl Into<String>) -> Self {
335        self.authorization_uri = Some(uri.into());
336        self
337    }
338
339    /// Set the token endpoint URL
340    pub fn token_uri(mut self, uri: impl Into<String>) -> Self {
341        self.token_uri = Some(uri.into());
342        self
343    }
344
345    /// Set the user info endpoint URL
346    pub fn userinfo_uri(mut self, uri: impl Into<String>) -> Self {
347        self.userinfo_uri = Some(uri.into());
348        self
349    }
350
351    /// Set the OIDC issuer URL for auto-discovery
352    pub fn issuer_uri(mut self, uri: impl Into<String>) -> Self {
353        self.issuer_uri = Some(uri.into());
354        self
355    }
356
357    /// Set the JWK Set URI for ID token validation
358    pub fn jwk_set_uri(mut self, uri: impl Into<String>) -> Self {
359        self.jwk_set_uri = Some(uri.into());
360        self
361    }
362
363    /// Set the OAuth2 scopes
364    pub fn scopes(mut self, scopes: Vec<impl Into<String>>) -> Self {
365        self.scopes = scopes.into_iter().map(|s| s.into()).collect();
366        self
367    }
368
369    /// Add a scope
370    pub fn add_scope(mut self, scope: impl Into<String>) -> Self {
371        self.scopes.push(scope.into());
372        self
373    }
374
375    /// Enable or disable PKCE
376    pub fn use_pkce(mut self, use_pkce: bool) -> Self {
377        self.use_pkce = use_pkce;
378        self
379    }
380
381    /// Add a custom authorization parameter
382    pub fn authorization_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
383        self.authorization_params.insert(key.into(), value.into());
384        self
385    }
386
387    /// Set the attribute name used for extracting the username
388    pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
389        self.username_attribute = attr.into();
390        self
391    }
392}
393
394/// User information retrieved from OAuth2 provider
395#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct OAuth2User {
397    /// User's unique identifier (subject)
398    pub sub: String,
399    /// User's name
400    pub name: Option<String>,
401    /// User's given name
402    pub given_name: Option<String>,
403    /// User's family name
404    pub family_name: Option<String>,
405    /// User's email
406    pub email: Option<String>,
407    /// Whether email is verified
408    pub email_verified: Option<bool>,
409    /// User's picture URL
410    pub picture: Option<String>,
411    /// User's locale
412    pub locale: Option<String>,
413    /// Provider-specific attributes
414    #[serde(flatten)]
415    pub attributes: HashMap<String, serde_json::Value>,
416    /// OAuth2 access token
417    #[serde(skip)]
418    pub access_token: Option<String>,
419    /// OAuth2 refresh token
420    #[serde(skip)]
421    pub refresh_token: Option<String>,
422    /// Token expiration time (Unix timestamp)
423    pub expires_at: Option<i64>,
424    /// Provider that authenticated this user
425    pub provider: String,
426}
427
428impl OAuth2User {
429    /// Create a new OAuth2User with minimal information
430    pub fn new(sub: impl Into<String>, provider: impl Into<String>) -> Self {
431        Self {
432            sub: sub.into(),
433            name: None,
434            given_name: None,
435            family_name: None,
436            email: None,
437            email_verified: None,
438            picture: None,
439            locale: None,
440            attributes: HashMap::new(),
441            access_token: None,
442            refresh_token: None,
443            expires_at: None,
444            provider: provider.into(),
445        }
446    }
447
448    /// Get a specific attribute value
449    pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
450        self.attributes.get(key)
451    }
452
453    /// Get the username (tries email first, then sub)
454    pub fn username(&self) -> &str {
455        self.email.as_deref().unwrap_or(&self.sub)
456    }
457
458    /// Convert to a User for authentication
459    pub fn to_user(&self) -> User {
460        User::new(self.username().to_string(), String::new())
461            .roles(&["USER".to_string()])
462            .authorities(&[format!("OAUTH2_USER_{}", self.provider.to_uppercase())])
463    }
464}
465
466/// OIDC user with ID token claims
467#[derive(Debug, Clone, Serialize, Deserialize)]
468pub struct OidcUser {
469    /// Base OAuth2 user info
470    #[serde(flatten)]
471    pub oauth2_user: OAuth2User,
472    /// ID token claims
473    pub id_token_claims: Option<IdTokenClaims>,
474    /// Raw ID token (JWT)
475    #[serde(skip)]
476    pub id_token: Option<String>,
477}
478
479/// ID Token claims
480#[derive(Debug, Clone, Serialize, Deserialize)]
481pub struct IdTokenClaims {
482    /// Issuer
483    pub iss: String,
484    /// Subject
485    pub sub: String,
486    /// Audience
487    pub aud: Vec<String>,
488    /// Expiration time
489    pub exp: i64,
490    /// Issued at time
491    pub iat: i64,
492    /// Authentication time
493    pub auth_time: Option<i64>,
494    /// Nonce
495    pub nonce: Option<String>,
496    /// Access token hash
497    pub at_hash: Option<String>,
498}
499
500impl OidcUser {
501    /// Convert to a User for authentication
502    pub fn to_user(&self) -> User {
503        self.oauth2_user.to_user()
504    }
505}
506
507/// Authorization request state (stored in session)
508#[derive(Debug, Clone, Serialize, Deserialize)]
509pub struct AuthorizationRequestState {
510    /// CSRF state token
511    pub state: String,
512    /// PKCE code verifier (if PKCE is used)
513    pub pkce_verifier: Option<String>,
514    /// OIDC nonce (if OIDC is used)
515    pub nonce: Option<String>,
516    /// Redirect URL after successful authentication
517    pub redirect_uri: Option<String>,
518    /// Provider registration ID
519    pub registration_id: String,
520    /// Timestamp when the request was created
521    pub created_at: i64,
522}
523
524/// OAuth2 client for handling authorization flows
525///
526/// Similar to Spring Security's `OAuth2AuthorizedClientService`.
527#[derive(Clone)]
528pub struct OAuth2Client {
529    config: OAuth2Config,
530    oauth2_client: BasicClient,
531    oidc_client: Option<Arc<CoreClient>>,
532}
533
534impl OAuth2Client {
535    /// Create a new OAuth2 client from configuration
536    ///
537    /// For OIDC providers, this will perform discovery to fetch provider metadata.
538    pub async fn new(config: OAuth2Config) -> Result<Self, OAuth2Error> {
539        // Build the basic OAuth2 client
540        let auth_url = config
541            .authorization_uri
542            .as_ref()
543            .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
544
545        let token_url = config
546            .token_uri
547            .as_ref()
548            .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
549
550        let oauth2_client = BasicClient::new(
551            ClientId::new(config.client_id.clone()),
552            Some(ClientSecret::new(config.client_secret.clone())),
553            AuthUrl::new(auth_url.clone())
554                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
555            Some(
556                TokenUrl::new(token_url.clone())
557                    .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
558            ),
559        )
560        .set_redirect_uri(
561            RedirectUrl::new(config.redirect_uri.clone())
562                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
563        );
564
565        // For OIDC providers, try to create an OIDC client
566        let oidc_client = if config.provider.supports_oidc() {
567            if let Some(issuer_uri) = &config.issuer_uri {
568                match Self::create_oidc_client(&config, issuer_uri).await {
569                    Ok(client) => Some(Arc::new(client)),
570                    Err(e) => {
571                        // Log warning but continue without OIDC
572                        eprintln!("Warning: OIDC discovery failed: {}", e);
573                        None
574                    }
575                }
576            } else if let Some(discovery_url) = config.provider.discovery_url() {
577                // Extract issuer from discovery URL
578                let issuer = discovery_url.trim_end_matches("/.well-known/openid-configuration");
579                match Self::create_oidc_client(&config, issuer).await {
580                    Ok(client) => Some(Arc::new(client)),
581                    Err(e) => {
582                        eprintln!("Warning: OIDC discovery failed: {}", e);
583                        None
584                    }
585                }
586            } else {
587                None
588            }
589        } else {
590            None
591        };
592
593        Ok(Self {
594            config,
595            oauth2_client,
596            oidc_client,
597        })
598    }
599
600    /// Create a new OAuth2 client without OIDC discovery (sync)
601    ///
602    /// Use this when you don't need OIDC features or want to avoid async initialization.
603    pub fn new_basic(config: OAuth2Config) -> Result<Self, OAuth2Error> {
604        let auth_url = config
605            .authorization_uri
606            .as_ref()
607            .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
608
609        let token_url = config
610            .token_uri
611            .as_ref()
612            .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
613
614        let oauth2_client = BasicClient::new(
615            ClientId::new(config.client_id.clone()),
616            Some(ClientSecret::new(config.client_secret.clone())),
617            AuthUrl::new(auth_url.clone())
618                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
619            Some(
620                TokenUrl::new(token_url.clone())
621                    .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
622            ),
623        )
624        .set_redirect_uri(
625            RedirectUrl::new(config.redirect_uri.clone())
626                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
627        );
628
629        Ok(Self {
630            config,
631            oauth2_client,
632            oidc_client: None,
633        })
634    }
635
636    /// Create an OIDC client with discovery
637    async fn create_oidc_client(
638        config: &OAuth2Config,
639        issuer_uri: &str,
640    ) -> Result<CoreClient, OAuth2Error> {
641        let issuer_url = IssuerUrl::new(issuer_uri.to_string())
642            .map_err(|e| OAuth2Error::Configuration(e.to_string()))?;
643
644        // Discover provider metadata using openidconnect's async http client
645        let provider_metadata = CoreProviderMetadata::discover_async(
646            issuer_url,
647            openidconnect::reqwest::async_http_client,
648        )
649        .await
650        .map_err(|e| OAuth2Error::Discovery(format!("{:?}", e)))?;
651
652        // Build OIDC client
653        let client = CoreClient::from_provider_metadata(
654            provider_metadata,
655            OidcClientId::new(config.client_id.clone()),
656            Some(OidcClientSecret::new(config.client_secret.clone())),
657        )
658        .set_redirect_uri(
659            OidcRedirectUrl::new(config.redirect_uri.clone())
660                .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
661        );
662
663        Ok(client)
664    }
665
666    /// Generate an authorization URL for the OAuth2 flow
667    ///
668    /// Returns (authorization_url, state, pkce_verifier, nonce)
669    pub fn authorization_url(&self) -> (Url, CsrfToken, Option<PkceCodeVerifier>, Option<Nonce>) {
670        if let Some(oidc_client) = &self.oidc_client {
671            // OIDC flow with nonce
672            let mut auth_request = oidc_client.authorize_url(
673                CoreAuthenticationFlow::AuthorizationCode,
674                CsrfToken::new_random,
675                Nonce::new_random,
676            );
677
678            // Add scopes
679            for scope in &self.config.scopes {
680                auth_request = auth_request.add_scope(openidconnect::Scope::new(scope.clone()));
681            }
682
683            // Add PKCE if enabled
684            let pkce_verifier = if self.config.use_pkce {
685                let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
686                auth_request = auth_request.set_pkce_challenge(pkce_challenge);
687                Some(pkce_verifier)
688            } else {
689                None
690            };
691
692            let (url, state, nonce) = auth_request.url();
693            (url, state, pkce_verifier, Some(nonce))
694        } else {
695            // Standard OAuth2 flow
696            let mut auth_request = self.oauth2_client.authorize_url(CsrfToken::new_random);
697
698            // Add scopes
699            for scope in &self.config.scopes {
700                auth_request = auth_request.add_scope(Scope::new(scope.clone()));
701            }
702
703            // Add PKCE if enabled
704            let pkce_verifier = if self.config.use_pkce {
705                let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
706                auth_request = auth_request.set_pkce_challenge(pkce_challenge);
707                Some(pkce_verifier)
708            } else {
709                None
710            };
711
712            let (url, state) = auth_request.url();
713            (url, state, pkce_verifier, None)
714        }
715    }
716
717    /// Exchange authorization code for tokens
718    pub async fn exchange_code(
719        &self,
720        code: &str,
721        pkce_verifier: Option<PkceCodeVerifier>,
722        nonce: Option<&Nonce>,
723    ) -> Result<(OAuth2User, Option<OidcUser>), OAuth2Error> {
724        if let Some(oidc_client) = &self.oidc_client {
725            // OIDC token exchange using openidconnect's async client
726            let mut token_request =
727                oidc_client.exchange_code(AuthorizationCode::new(code.to_string()));
728
729            if let Some(verifier) = pkce_verifier {
730                token_request = token_request.set_pkce_verifier(verifier);
731            }
732
733            let token_response = token_request
734                .request_async(openidconnect::reqwest::async_http_client)
735                .await
736                .map_err(|e| OAuth2Error::TokenExchange(format!("{:?}", e)))?;
737
738            // Verify and extract ID token claims
739            let id_token = token_response
740                .id_token()
741                .ok_or_else(|| OAuth2Error::TokenValidation("Missing ID token".to_string()))?;
742
743            let id_token_verifier = oidc_client.id_token_verifier();
744            let nonce_ref = nonce.cloned().unwrap_or_else(|| Nonce::new(String::new()));
745            let claims = id_token.claims(&id_token_verifier, &nonce_ref).map_err(
746                |e: openidconnect::ClaimsVerificationError| {
747                    OAuth2Error::TokenValidation(e.to_string())
748                },
749            )?;
750
751            // Extract basic info from claims
752            let subject = claims.subject().as_str().to_string();
753            let issuer = claims.issuer().as_str().to_string();
754            let exp = claims.expiration().timestamp();
755            let iat = claims.issue_time().timestamp();
756
757            // Build OAuth2User
758            let mut oauth2_user = OAuth2User::new(&subject, &self.config.registration_id);
759            oauth2_user.access_token = Some(token_response.access_token().secret().clone());
760            oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
761            oauth2_user.email_verified = claims.email_verified();
762
763            // Build OidcUser with basic claims
764            let oidc_user = OidcUser {
765                oauth2_user: oauth2_user.clone(),
766                id_token_claims: Some(IdTokenClaims {
767                    iss: issuer,
768                    sub: subject,
769                    aud: vec![self.config.client_id.clone()],
770                    exp,
771                    iat,
772                    auth_time: None,
773                    nonce: None,
774                    at_hash: None,
775                }),
776                id_token: Some(id_token.to_string()),
777            };
778
779            Ok((oauth2_user, Some(oidc_user)))
780        } else {
781            // Standard OAuth2 token exchange
782            let mut token_request = self
783                .oauth2_client
784                .exchange_code(AuthorizationCode::new(code.to_string()));
785
786            if let Some(verifier) = pkce_verifier {
787                token_request = token_request.set_pkce_verifier(verifier);
788            }
789
790            let token_response = token_request
791                .request_async(async_http_client)
792                .await
793                .map_err(|e| OAuth2Error::TokenExchange(e.to_string()))?;
794
795            // Fetch user info
796            let mut oauth2_user = self
797                .fetch_user_info(token_response.access_token().secret())
798                .await?;
799
800            oauth2_user.access_token = Some(token_response.access_token().secret().clone());
801            oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
802
803            Ok((oauth2_user, None))
804        }
805    }
806
807    /// Fetch user info from the provider's userinfo endpoint
808    async fn fetch_user_info(&self, access_token: &str) -> Result<OAuth2User, OAuth2Error> {
809        let userinfo_url = self
810            .config
811            .userinfo_uri
812            .as_ref()
813            .ok_or_else(|| OAuth2Error::UserInfo("Missing userinfo URI".to_string()))?;
814
815        let http_client = reqwest::Client::new();
816        let response = http_client
817            .get(userinfo_url)
818            .bearer_auth(access_token)
819            .header("Accept", "application/json")
820            .send()
821            .await
822            .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
823
824        if !response.status().is_success() {
825            return Err(OAuth2Error::UserInfo(format!(
826                "HTTP {}: {}",
827                response.status(),
828                response.text().await.unwrap_or_default()
829            )));
830        }
831
832        let attributes: HashMap<String, serde_json::Value> = response
833            .json()
834            .await
835            .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
836
837        // Extract user ID based on provider
838        let sub = self.extract_user_id(&attributes)?;
839        let mut user = OAuth2User::new(sub, &self.config.registration_id);
840        user.access_token = Some(access_token.to_string());
841        user.attributes = attributes.clone();
842
843        // Extract common fields
844        user.name = attributes
845            .get("name")
846            .and_then(|v| v.as_str())
847            .map(String::from);
848        user.email = attributes
849            .get("email")
850            .and_then(|v| v.as_str())
851            .map(String::from);
852        user.picture = attributes
853            .get("picture")
854            .or_else(|| attributes.get("avatar_url"))
855            .and_then(|v| v.as_str())
856            .map(String::from);
857
858        Ok(user)
859    }
860
861    /// Extract user ID from attributes based on provider
862    fn extract_user_id(
863        &self,
864        attributes: &HashMap<String, serde_json::Value>,
865    ) -> Result<String, OAuth2Error> {
866        // Try common ID fields
867        let id_fields = ["sub", "id", "user_id", "login"];
868
869        for field in &id_fields {
870            if let Some(value) = attributes.get(*field) {
871                if let Some(s) = value.as_str() {
872                    return Ok(s.to_string());
873                }
874                if let Some(n) = value.as_i64() {
875                    return Ok(n.to_string());
876                }
877            }
878        }
879
880        Err(OAuth2Error::UserInfo(
881            "Could not extract user ID".to_string(),
882        ))
883    }
884
885    /// Get the configuration
886    pub fn config(&self) -> &OAuth2Config {
887        &self.config
888    }
889
890    /// Check if OIDC is available
891    pub fn has_oidc(&self) -> bool {
892        self.oidc_client.is_some()
893    }
894}
895
896/// Repository for multiple OAuth2 client registrations
897///
898/// Similar to Spring Security's `ClientRegistrationRepository`.
899#[derive(Clone, Default)]
900pub struct OAuth2ClientRepository {
901    clients: HashMap<String, OAuth2Client>,
902}
903
904impl OAuth2ClientRepository {
905    /// Create a new empty repository
906    pub fn new() -> Self {
907        Self {
908            clients: HashMap::new(),
909        }
910    }
911
912    /// Add a client registration
913    pub fn add_client(&mut self, client: OAuth2Client) {
914        self.clients
915            .insert(client.config.registration_id.clone(), client);
916    }
917
918    /// Get a client by registration ID
919    pub fn get_client(&self, registration_id: &str) -> Option<&OAuth2Client> {
920        self.clients.get(registration_id)
921    }
922
923    /// Get all registration IDs
924    pub fn registration_ids(&self) -> Vec<&String> {
925        self.clients.keys().collect()
926    }
927
928    /// Build a repository from multiple configurations
929    pub async fn from_configs(configs: Vec<OAuth2Config>) -> Result<Self, OAuth2Error> {
930        let mut repo = Self::new();
931        for config in configs {
932            let client = OAuth2Client::new(config).await?;
933            repo.add_client(client);
934        }
935        Ok(repo)
936    }
937}
938
/// Bearer-token authenticator for OAuth2-protected resources.
///
/// Looks for a `Bearer` token in the `Authorization` header. The issuer, JWKS URI and
/// username attribute configured here are intended for validating that token and mapping
/// it to a user.
///
/// NOTE(review): the `Authenticator::get_user` implementation for this type does not yet
/// validate tokens and always returns `None`, so these settings are currently not consulted.
#[derive(Clone, Debug)]
pub struct OAuth2Authenticator {
    /// Expected issuer for token validation.
    issuer: Option<String>,
    /// JWKS URI for token validation.
    jwks_uri: Option<String>,
    /// Attribute to use as the username (`new` defaults it to `"sub"`).
    username_attribute: String,
}
952
953impl OAuth2Authenticator {
954    /// Create a new OAuth2 authenticator
955    pub fn new() -> Self {
956        Self {
957            issuer: None,
958            jwks_uri: None,
959            username_attribute: "sub".to_string(),
960        }
961    }
962
963    /// Set the expected issuer
964    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
965        self.issuer = Some(issuer.into());
966        self
967    }
968
969    /// Set the JWKS URI for token validation
970    pub fn jwks_uri(mut self, uri: impl Into<String>) -> Self {
971        self.jwks_uri = Some(uri.into());
972        self
973    }
974
975    /// Set the attribute to use as username
976    pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
977        self.username_attribute = attr.into();
978        self
979    }
980
981    /// Extract Bearer token from request
982    fn extract_token(&self, req: &ServiceRequest) -> Option<String> {
983        let auth_header = req.headers().get(AUTHORIZATION)?;
984        let auth_str = auth_header.to_str().ok()?;
985
986        auth_str
987            .strip_prefix("Bearer ")
988            .map(|token| token.to_string())
989    }
990}
991
992impl Default for OAuth2Authenticator {
993    fn default() -> Self {
994        Self::new()
995    }
996}
997
998impl Authenticator for OAuth2Authenticator {
999    fn get_user(&self, req: &ServiceRequest) -> Option<User> {
1000        // This is a simplified implementation
1001        // In production, you would validate the token against JWKS
1002        let _token = self.extract_token(req)?;
1003
1004        // For now, we return None as token validation requires async
1005        // The actual validation should be done via OAuth2CallbackHandler
1006        None
1007    }
1008}
1009
#[cfg(test)]
mod tests {
    use super::*;

    /// A Google config with an extra scope includes `openid`, `email` and the custom scope,
    /// and has PKCE enabled.
    #[test]
    fn test_oauth2_config_builder() {
        let cfg = OAuth2Config::new("client-id", "secret", "http://localhost/callback")
            .provider(OAuth2Provider::Google)
            .add_scope("custom_scope");

        assert_eq!(cfg.client_id, "client-id");
        assert_eq!(cfg.provider, OAuth2Provider::Google);
        for scope in &["openid", "email", "custom_scope"] {
            assert!(cfg.scopes.contains(&scope.to_string()));
        }
        assert!(cfg.use_pkce);
    }

    /// Google provides auth, token and userinfo URLs plus OIDC; GitHub has an auth URL but no OIDC.
    #[test]
    fn test_oauth2_provider_endpoints() {
        assert!(OAuth2Provider::Google.auth_url().is_some());
        assert!(OAuth2Provider::Google.token_url().is_some());
        assert!(OAuth2Provider::Google.userinfo_url().is_some());
        assert!(OAuth2Provider::Google.supports_oidc());

        assert!(OAuth2Provider::GitHub.auth_url().is_some());
        assert!(!OAuth2Provider::GitHub.supports_oidc());
    }

    /// The email is used as the username, both on the OAuth2 user and on the converted user,
    /// and the converted user has the `USER` role.
    #[test]
    fn test_oauth2_user() {
        let mut oauth_user = OAuth2User::new("user123", "google");
        oauth_user.name = Some("Test User".to_string());
        oauth_user.email = Some("user@example.com".to_string());

        assert_eq!(oauth_user.username(), "user@example.com");

        let converted = oauth_user.to_user();
        assert_eq!(converted.get_username(), "user@example.com");
        assert!(converted.get_roles().contains(&"USER".to_string()));
    }

    /// Each built-in provider advertises its expected default scopes.
    #[test]
    fn test_provider_default_scopes() {
        let google = OAuth2Provider::Google.default_scopes();
        for scope in &["openid", "email"] {
            assert!(google.contains(scope));
        }

        let github = OAuth2Provider::GitHub.default_scopes();
        assert!(github.contains(&"read:user"));
    }

    /// A basic (non-OIDC) client can be built for GitHub without network access.
    #[test]
    fn test_oauth2_client_basic() {
        let github_config = OAuth2Config::new("client-id", "secret", "http://localhost/callback")
            .provider(OAuth2Provider::GitHub);

        let basic = OAuth2Client::new_basic(github_config).unwrap();
        assert!(!basic.has_oidc());
    }
}