1use std::collections::HashMap;
47use std::fmt;
48use std::sync::Arc;
49
50use actix_web::dev::ServiceRequest;
51use actix_web::http::header::AUTHORIZATION;
52use oauth2::basic::BasicClient;
53use oauth2::reqwest::async_http_client;
54use oauth2::{
55 AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
56 PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl,
57};
58use openidconnect::core::{
59 CoreAuthenticationFlow, CoreClient, CoreProviderMetadata,
60};
61use openidconnect::{
62 ClientId as OidcClientId, ClientSecret as OidcClientSecret, IssuerUrl, Nonce,
63 RedirectUrl as OidcRedirectUrl, TokenResponse as OidcTokenResponse,
64};
65use serde::{Deserialize, Serialize};
66use url::Url;
67
68use super::config::Authenticator;
69use super::user::User;
70
/// Errors produced while configuring OAuth2/OIDC clients and running the
/// authorization-code flow.
///
/// Every variant carries a free-form detail message; `Display` renders it as
/// `"<category>: <detail>"` (e.g. `"Discovery error: timeout"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OAuth2Error {
    /// Invalid or incomplete client configuration (missing or unparseable URIs).
    Configuration(String),
    /// OpenID Connect provider-metadata discovery failed.
    Discovery(String),
    /// Exchanging the authorization code for tokens failed.
    TokenExchange(String),
    /// The ID token was missing or failed verification.
    TokenValidation(String),
    /// Fetching or interpreting the userinfo response failed.
    UserInfo(String),
    /// The `state` value did not validate.
    InvalidState(String),
    /// The OIDC nonce was missing or did not validate.
    InvalidNonce(String),
}

impl fmt::Display for OAuth2Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each variant to its human-readable category label.
        let (category, detail) = match self {
            OAuth2Error::Configuration(msg) => ("Configuration error", msg),
            OAuth2Error::Discovery(msg) => ("Discovery error", msg),
            OAuth2Error::TokenExchange(msg) => ("Token exchange error", msg),
            OAuth2Error::TokenValidation(msg) => ("Token validation error", msg),
            OAuth2Error::UserInfo(msg) => ("User info error", msg),
            OAuth2Error::InvalidState(msg) => ("Invalid state", msg),
            OAuth2Error::InvalidNonce(msg) => ("Invalid nonce", msg),
        };
        write!(f, "{}: {}", category, detail)
    }
}

impl std::error::Error for OAuth2Error {}
105
/// Identity providers with built-in defaults (endpoints, scopes, OIDC support).
///
/// `Custom` has no built-in endpoints; its URIs must be configured explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OAuth2Provider {
    /// Google (OpenID Connect).
    Google,
    /// GitHub (plain OAuth2, no OIDC).
    GitHub,
    /// Microsoft identity platform, `common` tenant endpoints (OpenID Connect).
    Microsoft,
    /// Facebook (plain OAuth2, no OIDC).
    Facebook,
    /// Sign in with Apple (OpenID Connect; no built-in userinfo endpoint).
    Apple,
    /// Okta (OpenID Connect; no built-in endpoints, issuer must be configured).
    Okta,
    /// Auth0 (OpenID Connect; no built-in endpoints, issuer must be configured).
    Auth0,
    /// Keycloak (OpenID Connect; no built-in endpoints, issuer must be configured).
    Keycloak,
    /// Any other provider; all endpoints must be configured explicitly.
    Custom,
}
128
129impl OAuth2Provider {
130 pub fn discovery_url(&self) -> Option<&'static str> {
132 match self {
133 OAuth2Provider::Google => {
134 Some("https://accounts.google.com/.well-known/openid-configuration")
135 }
136 OAuth2Provider::Microsoft => {
137 Some("https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration")
138 }
139 OAuth2Provider::Apple => {
140 Some("https://appleid.apple.com/.well-known/openid-configuration")
141 }
142 _ => None,
143 }
144 }
145
146 pub fn auth_url(&self) -> Option<&'static str> {
148 match self {
149 OAuth2Provider::Google => Some("https://accounts.google.com/o/oauth2/v2/auth"),
150 OAuth2Provider::GitHub => Some("https://github.com/login/oauth/authorize"),
151 OAuth2Provider::Microsoft => {
152 Some("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
153 }
154 OAuth2Provider::Facebook => Some("https://www.facebook.com/v18.0/dialog/oauth"),
155 OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/authorize"),
156 _ => None,
157 }
158 }
159
160 pub fn token_url(&self) -> Option<&'static str> {
162 match self {
163 OAuth2Provider::Google => Some("https://oauth2.googleapis.com/token"),
164 OAuth2Provider::GitHub => Some("https://github.com/login/oauth/access_token"),
165 OAuth2Provider::Microsoft => {
166 Some("https://login.microsoftonline.com/common/oauth2/v2.0/token")
167 }
168 OAuth2Provider::Facebook => {
169 Some("https://graph.facebook.com/v18.0/oauth/access_token")
170 }
171 OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/token"),
172 _ => None,
173 }
174 }
175
176 pub fn userinfo_url(&self) -> Option<&'static str> {
178 match self {
179 OAuth2Provider::Google => Some("https://openidconnect.googleapis.com/v1/userinfo"),
180 OAuth2Provider::GitHub => Some("https://api.github.com/user"),
181 OAuth2Provider::Microsoft => Some("https://graph.microsoft.com/oidc/userinfo"),
182 OAuth2Provider::Facebook => Some("https://graph.facebook.com/me?fields=id,name,email"),
183 _ => None,
184 }
185 }
186
187 pub fn default_scopes(&self) -> Vec<&'static str> {
189 match self {
190 OAuth2Provider::Google => vec!["openid", "email", "profile"],
191 OAuth2Provider::GitHub => vec!["read:user", "user:email"],
192 OAuth2Provider::Microsoft => vec!["openid", "email", "profile"],
193 OAuth2Provider::Facebook => vec!["email", "public_profile"],
194 OAuth2Provider::Apple => vec!["openid", "email", "name"],
195 OAuth2Provider::Okta => vec!["openid", "email", "profile"],
196 OAuth2Provider::Auth0 => vec!["openid", "email", "profile"],
197 OAuth2Provider::Keycloak => vec!["openid", "email", "profile"],
198 OAuth2Provider::Custom => vec!["openid"],
199 }
200 }
201
202 pub fn supports_oidc(&self) -> bool {
204 matches!(
205 self,
206 OAuth2Provider::Google
207 | OAuth2Provider::Microsoft
208 | OAuth2Provider::Apple
209 | OAuth2Provider::Okta
210 | OAuth2Provider::Auth0
211 | OAuth2Provider::Keycloak
212 )
213 }
214}
215
216#[derive(Debug, Clone)]
232pub struct OAuth2Config {
233 pub registration_id: String,
235 pub client_id: String,
237 pub client_secret: String,
239 pub redirect_uri: String,
241 pub provider: OAuth2Provider,
243 pub authorization_uri: Option<String>,
245 pub token_uri: Option<String>,
247 pub userinfo_uri: Option<String>,
249 pub issuer_uri: Option<String>,
251 pub jwk_set_uri: Option<String>,
253 pub scopes: Vec<String>,
255 pub use_pkce: bool,
257 pub authorization_params: HashMap<String, String>,
259 pub username_attribute: String,
261}
262
263impl OAuth2Config {
264 pub fn new(
272 client_id: impl Into<String>,
273 client_secret: impl Into<String>,
274 redirect_uri: impl Into<String>,
275 ) -> Self {
276 Self {
277 registration_id: String::new(),
278 client_id: client_id.into(),
279 client_secret: client_secret.into(),
280 redirect_uri: redirect_uri.into(),
281 provider: OAuth2Provider::Custom,
282 authorization_uri: None,
283 token_uri: None,
284 userinfo_uri: None,
285 issuer_uri: None,
286 jwk_set_uri: None,
287 scopes: vec!["openid".to_string()],
288 use_pkce: true,
289 authorization_params: HashMap::new(),
290 username_attribute: "sub".to_string(),
291 }
292 }
293
294 pub fn registration_id(mut self, id: impl Into<String>) -> Self {
296 self.registration_id = id.into();
297 self
298 }
299
300 pub fn provider(mut self, provider: OAuth2Provider) -> Self {
304 self.provider = provider;
305 if self.registration_id.is_empty() {
306 self.registration_id = format!("{:?}", provider).to_lowercase();
307 }
308
309 if let Some(auth_url) = provider.auth_url() {
311 self.authorization_uri = Some(auth_url.to_string());
312 }
313 if let Some(token_url) = provider.token_url() {
314 self.token_uri = Some(token_url.to_string());
315 }
316 if let Some(userinfo_url) = provider.userinfo_url() {
317 self.userinfo_uri = Some(userinfo_url.to_string());
318 }
319
320 if self.scopes.len() == 1 && self.scopes[0] == "openid" {
322 self.scopes = provider
323 .default_scopes()
324 .into_iter()
325 .map(String::from)
326 .collect();
327 }
328
329 if matches!(provider, OAuth2Provider::GitHub | OAuth2Provider::Facebook) {
331 self.use_pkce = false;
332 }
333
334 self
335 }
336
337 pub fn authorization_uri(mut self, uri: impl Into<String>) -> Self {
339 self.authorization_uri = Some(uri.into());
340 self
341 }
342
343 pub fn token_uri(mut self, uri: impl Into<String>) -> Self {
345 self.token_uri = Some(uri.into());
346 self
347 }
348
349 pub fn userinfo_uri(mut self, uri: impl Into<String>) -> Self {
351 self.userinfo_uri = Some(uri.into());
352 self
353 }
354
355 pub fn issuer_uri(mut self, uri: impl Into<String>) -> Self {
357 self.issuer_uri = Some(uri.into());
358 self
359 }
360
361 pub fn jwk_set_uri(mut self, uri: impl Into<String>) -> Self {
363 self.jwk_set_uri = Some(uri.into());
364 self
365 }
366
367 pub fn scopes(mut self, scopes: Vec<impl Into<String>>) -> Self {
369 self.scopes = scopes.into_iter().map(|s| s.into()).collect();
370 self
371 }
372
373 pub fn add_scope(mut self, scope: impl Into<String>) -> Self {
375 self.scopes.push(scope.into());
376 self
377 }
378
379 pub fn use_pkce(mut self, use_pkce: bool) -> Self {
381 self.use_pkce = use_pkce;
382 self
383 }
384
385 pub fn authorization_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
387 self.authorization_params.insert(key.into(), value.into());
388 self
389 }
390
391 pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
393 self.username_attribute = attr.into();
394 self
395 }
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct OAuth2User {
401 pub sub: String,
403 pub name: Option<String>,
405 pub given_name: Option<String>,
407 pub family_name: Option<String>,
409 pub email: Option<String>,
411 pub email_verified: Option<bool>,
413 pub picture: Option<String>,
415 pub locale: Option<String>,
417 #[serde(flatten)]
419 pub attributes: HashMap<String, serde_json::Value>,
420 #[serde(skip)]
422 pub access_token: Option<String>,
423 #[serde(skip)]
425 pub refresh_token: Option<String>,
426 pub expires_at: Option<i64>,
428 pub provider: String,
430}
431
432impl OAuth2User {
433 pub fn new(sub: impl Into<String>, provider: impl Into<String>) -> Self {
435 Self {
436 sub: sub.into(),
437 name: None,
438 given_name: None,
439 family_name: None,
440 email: None,
441 email_verified: None,
442 picture: None,
443 locale: None,
444 attributes: HashMap::new(),
445 access_token: None,
446 refresh_token: None,
447 expires_at: None,
448 provider: provider.into(),
449 }
450 }
451
452 pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
454 self.attributes.get(key)
455 }
456
457 pub fn username(&self) -> &str {
459 self.email.as_deref().unwrap_or(&self.sub)
460 }
461
462 pub fn to_user(&self) -> User {
464 User::new(self.username().to_string(), String::new())
465 .roles(&["USER".to_string()])
466 .authorities(&[format!("OAUTH2_USER_{}", self.provider.to_uppercase())])
467 }
468}
469
470#[derive(Debug, Clone, Serialize, Deserialize)]
472pub struct OidcUser {
473 #[serde(flatten)]
475 pub oauth2_user: OAuth2User,
476 pub id_token_claims: Option<IdTokenClaims>,
478 #[serde(skip)]
480 pub id_token: Option<String>,
481}
482
/// Subset of OpenID Connect ID-token claims kept alongside an [`OidcUser`].
///
/// Timestamps are Unix seconds (they come from chrono `timestamp()` values).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdTokenClaims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Subject (`sub`).
    pub sub: String,
    /// Audience (`aud`).
    pub aud: Vec<String>,
    /// Expiration time (`exp`), Unix seconds.
    pub exp: i64,
    /// Issued-at time (`iat`), Unix seconds.
    pub iat: i64,
    /// Authentication time (`auth_time`), Unix seconds; `None` when not captured.
    pub auth_time: Option<i64>,
    /// Nonce (`nonce`); `None` when not captured.
    pub nonce: Option<String>,
    /// Access-token hash (`at_hash`); `None` when not captured.
    pub at_hash: Option<String>,
}
503
504impl OidcUser {
505 pub fn to_user(&self) -> User {
507 self.oauth2_user.to_user()
508 }
509}
510
/// Serializable snapshot of an in-flight authorization request.
///
/// NOTE(review): not constructed by the code shown here. Presumably it is
/// persisted (e.g. in a session) between the redirect and the callback — confirm
/// with callers. It contains secrets (PKCE verifier, nonce), so store it
/// server-side or encrypted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationRequestState {
    /// CSRF `state` value sent with the authorization request.
    pub state: String,
    /// PKCE code verifier secret, when PKCE is used.
    pub pkce_verifier: Option<String>,
    /// OIDC nonce, when an OIDC flow is used.
    pub nonce: Option<String>,
    /// Optional redirect URI associated with this request (semantics defined by callers).
    pub redirect_uri: Option<String>,
    /// Registration id of the client that started the flow.
    pub registration_id: String,
    /// Creation timestamp. NOTE(review): unit not shown here (presumably Unix seconds).
    pub created_at: i64,
}
527
/// Client for one OAuth2 registration.
///
/// Always holds a plain OAuth2 client. It additionally holds an OpenID Connect
/// client when one was discovered at construction. Cloning shares the OIDC
/// client through an `Arc`.
#[derive(Clone)]
pub struct OAuth2Client {
    /// Registration this client was built from.
    config: OAuth2Config,
    /// Plain OAuth2 client, used whenever `oidc_client` is `None`.
    oauth2_client: BasicClient,
    /// Discovered OIDC client; when present, it drives authorization and code
    /// exchange instead of `oauth2_client`.
    oidc_client: Option<Arc<CoreClient>>,
}
537
538impl OAuth2Client {
539 pub async fn new(config: OAuth2Config) -> Result<Self, OAuth2Error> {
543 let auth_url = config
545 .authorization_uri
546 .as_ref()
547 .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
548
549 let token_url = config
550 .token_uri
551 .as_ref()
552 .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
553
554 let oauth2_client = BasicClient::new(
555 ClientId::new(config.client_id.clone()),
556 Some(ClientSecret::new(config.client_secret.clone())),
557 AuthUrl::new(auth_url.clone())
558 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
559 Some(
560 TokenUrl::new(token_url.clone())
561 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
562 ),
563 )
564 .set_redirect_uri(
565 RedirectUrl::new(config.redirect_uri.clone())
566 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
567 );
568
569 let oidc_client = if config.provider.supports_oidc() {
571 if let Some(issuer_uri) = &config.issuer_uri {
572 match Self::create_oidc_client(&config, issuer_uri).await {
573 Ok(client) => Some(Arc::new(client)),
574 Err(e) => {
575 eprintln!("Warning: OIDC discovery failed: {}", e);
577 None
578 }
579 }
580 } else if let Some(discovery_url) = config.provider.discovery_url() {
581 let issuer = discovery_url.trim_end_matches("/.well-known/openid-configuration");
583 match Self::create_oidc_client(&config, issuer).await {
584 Ok(client) => Some(Arc::new(client)),
585 Err(e) => {
586 eprintln!("Warning: OIDC discovery failed: {}", e);
587 None
588 }
589 }
590 } else {
591 None
592 }
593 } else {
594 None
595 };
596
597 Ok(Self {
598 config,
599 oauth2_client,
600 oidc_client,
601 })
602 }
603
604 pub fn new_basic(config: OAuth2Config) -> Result<Self, OAuth2Error> {
608 let auth_url = config
609 .authorization_uri
610 .as_ref()
611 .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
612
613 let token_url = config
614 .token_uri
615 .as_ref()
616 .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
617
618 let oauth2_client = BasicClient::new(
619 ClientId::new(config.client_id.clone()),
620 Some(ClientSecret::new(config.client_secret.clone())),
621 AuthUrl::new(auth_url.clone())
622 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
623 Some(
624 TokenUrl::new(token_url.clone())
625 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
626 ),
627 )
628 .set_redirect_uri(
629 RedirectUrl::new(config.redirect_uri.clone())
630 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
631 );
632
633 Ok(Self {
634 config,
635 oauth2_client,
636 oidc_client: None,
637 })
638 }
639
640 async fn create_oidc_client(
642 config: &OAuth2Config,
643 issuer_uri: &str,
644 ) -> Result<CoreClient, OAuth2Error> {
645 let issuer_url = IssuerUrl::new(issuer_uri.to_string())
646 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?;
647
648 let provider_metadata = CoreProviderMetadata::discover_async(
650 issuer_url,
651 openidconnect::reqwest::async_http_client,
652 )
653 .await
654 .map_err(|e| OAuth2Error::Discovery(format!("{:?}", e)))?;
655
656 let client = CoreClient::from_provider_metadata(
658 provider_metadata,
659 OidcClientId::new(config.client_id.clone()),
660 Some(OidcClientSecret::new(config.client_secret.clone())),
661 )
662 .set_redirect_uri(
663 OidcRedirectUrl::new(config.redirect_uri.clone())
664 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
665 );
666
667 Ok(client)
668 }
669
670 pub fn authorization_url(
674 &self,
675 ) -> (Url, CsrfToken, Option<PkceCodeVerifier>, Option<Nonce>) {
676 if let Some(oidc_client) = &self.oidc_client {
677 let mut auth_request = oidc_client.authorize_url(
679 CoreAuthenticationFlow::AuthorizationCode,
680 CsrfToken::new_random,
681 Nonce::new_random,
682 );
683
684 for scope in &self.config.scopes {
686 auth_request = auth_request.add_scope(openidconnect::Scope::new(scope.clone()));
687 }
688
689 let pkce_verifier = if self.config.use_pkce {
691 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
692 auth_request = auth_request.set_pkce_challenge(pkce_challenge);
693 Some(pkce_verifier)
694 } else {
695 None
696 };
697
698 let (url, state, nonce) = auth_request.url();
699 (url, state, pkce_verifier, Some(nonce))
700 } else {
701 let mut auth_request = self.oauth2_client.authorize_url(CsrfToken::new_random);
703
704 for scope in &self.config.scopes {
706 auth_request = auth_request.add_scope(Scope::new(scope.clone()));
707 }
708
709 let pkce_verifier = if self.config.use_pkce {
711 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
712 auth_request = auth_request.set_pkce_challenge(pkce_challenge);
713 Some(pkce_verifier)
714 } else {
715 None
716 };
717
718 let (url, state) = auth_request.url();
719 (url, state, pkce_verifier, None)
720 }
721 }
722
723 pub async fn exchange_code(
725 &self,
726 code: &str,
727 pkce_verifier: Option<PkceCodeVerifier>,
728 nonce: Option<&Nonce>,
729 ) -> Result<(OAuth2User, Option<OidcUser>), OAuth2Error> {
730 if let Some(oidc_client) = &self.oidc_client {
731 let mut token_request =
733 oidc_client.exchange_code(AuthorizationCode::new(code.to_string()));
734
735 if let Some(verifier) = pkce_verifier {
736 token_request = token_request.set_pkce_verifier(verifier);
737 }
738
739 let token_response = token_request
740 .request_async(openidconnect::reqwest::async_http_client)
741 .await
742 .map_err(|e| OAuth2Error::TokenExchange(format!("{:?}", e)))?;
743
744 let id_token = token_response
746 .id_token()
747 .ok_or_else(|| OAuth2Error::TokenValidation("Missing ID token".to_string()))?;
748
749 let id_token_verifier = oidc_client.id_token_verifier();
750 let nonce_ref = nonce.cloned().unwrap_or_else(|| Nonce::new(String::new()));
751 let claims = id_token
752 .claims(&id_token_verifier, &nonce_ref)
753 .map_err(|e: openidconnect::ClaimsVerificationError| {
754 OAuth2Error::TokenValidation(e.to_string())
755 })?;
756
757 let subject = claims.subject().as_str().to_string();
759 let issuer = claims.issuer().as_str().to_string();
760 let exp = claims.expiration().timestamp();
761 let iat = claims.issue_time().timestamp();
762
763 let mut oauth2_user = OAuth2User::new(&subject, &self.config.registration_id);
765 oauth2_user.access_token = Some(token_response.access_token().secret().clone());
766 oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
767 oauth2_user.email_verified = claims.email_verified();
768
769 let oidc_user = OidcUser {
771 oauth2_user: oauth2_user.clone(),
772 id_token_claims: Some(IdTokenClaims {
773 iss: issuer,
774 sub: subject,
775 aud: vec![self.config.client_id.clone()],
776 exp,
777 iat,
778 auth_time: None,
779 nonce: None,
780 at_hash: None,
781 }),
782 id_token: Some(id_token.to_string()),
783 };
784
785 Ok((oauth2_user, Some(oidc_user)))
786 } else {
787 let mut token_request = self
789 .oauth2_client
790 .exchange_code(AuthorizationCode::new(code.to_string()));
791
792 if let Some(verifier) = pkce_verifier {
793 token_request = token_request.set_pkce_verifier(verifier);
794 }
795
796 let token_response = token_request
797 .request_async(async_http_client)
798 .await
799 .map_err(|e| OAuth2Error::TokenExchange(e.to_string()))?;
800
801 let mut oauth2_user = self
803 .fetch_user_info(token_response.access_token().secret())
804 .await?;
805
806 oauth2_user.access_token = Some(token_response.access_token().secret().clone());
807 oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
808
809 Ok((oauth2_user, None))
810 }
811 }
812
813 async fn fetch_user_info(&self, access_token: &str) -> Result<OAuth2User, OAuth2Error> {
815 let userinfo_url = self
816 .config
817 .userinfo_uri
818 .as_ref()
819 .ok_or_else(|| OAuth2Error::UserInfo("Missing userinfo URI".to_string()))?;
820
821 let http_client = reqwest::Client::new();
822 let response = http_client
823 .get(userinfo_url)
824 .bearer_auth(access_token)
825 .header("Accept", "application/json")
826 .send()
827 .await
828 .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
829
830 if !response.status().is_success() {
831 return Err(OAuth2Error::UserInfo(format!(
832 "HTTP {}: {}",
833 response.status(),
834 response.text().await.unwrap_or_default()
835 )));
836 }
837
838 let attributes: HashMap<String, serde_json::Value> = response
839 .json()
840 .await
841 .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
842
843 let sub = self.extract_user_id(&attributes)?;
845 let mut user = OAuth2User::new(sub, &self.config.registration_id);
846 user.access_token = Some(access_token.to_string());
847 user.attributes = attributes.clone();
848
849 user.name = attributes
851 .get("name")
852 .and_then(|v| v.as_str())
853 .map(String::from);
854 user.email = attributes
855 .get("email")
856 .and_then(|v| v.as_str())
857 .map(String::from);
858 user.picture = attributes
859 .get("picture")
860 .or_else(|| attributes.get("avatar_url"))
861 .and_then(|v| v.as_str())
862 .map(String::from);
863
864 Ok(user)
865 }
866
867 fn extract_user_id(
869 &self,
870 attributes: &HashMap<String, serde_json::Value>,
871 ) -> Result<String, OAuth2Error> {
872 let id_fields = ["sub", "id", "user_id", "login"];
874
875 for field in &id_fields {
876 if let Some(value) = attributes.get(*field) {
877 if let Some(s) = value.as_str() {
878 return Ok(s.to_string());
879 }
880 if let Some(n) = value.as_i64() {
881 return Ok(n.to_string());
882 }
883 }
884 }
885
886 Err(OAuth2Error::UserInfo(
887 "Could not extract user ID".to_string(),
888 ))
889 }
890
891 pub fn config(&self) -> &OAuth2Config {
893 &self.config
894 }
895
896 pub fn has_oidc(&self) -> bool {
898 self.oidc_client.is_some()
899 }
900}
901
902#[derive(Clone, Default)]
906pub struct OAuth2ClientRepository {
907 clients: HashMap<String, OAuth2Client>,
908}
909
910impl OAuth2ClientRepository {
911 pub fn new() -> Self {
913 Self {
914 clients: HashMap::new(),
915 }
916 }
917
918 pub fn add_client(&mut self, client: OAuth2Client) {
920 self.clients
921 .insert(client.config.registration_id.clone(), client);
922 }
923
924 pub fn get_client(&self, registration_id: &str) -> Option<&OAuth2Client> {
926 self.clients.get(registration_id)
927 }
928
929 pub fn registration_ids(&self) -> Vec<&String> {
931 self.clients.keys().collect()
932 }
933
934 pub async fn from_configs(configs: Vec<OAuth2Config>) -> Result<Self, OAuth2Error> {
936 let mut repo = Self::new();
937 for config in configs {
938 let client = OAuth2Client::new(config).await?;
939 repo.add_client(client);
940 }
941 Ok(repo)
942 }
943}
944
/// Request authenticator that looks for an `Authorization: Bearer` token.
///
/// NOTE(review): `issuer`, `jwks_uri` and `username_attribute` are set by the
/// builder methods but are not read by any code shown here, and
/// `Authenticator::get_user` never returns a user — token validation is not
/// implemented yet.
#[derive(Clone)]
pub struct OAuth2Authenticator {
    /// Configured token issuer, if any.
    issuer: Option<String>,
    /// Configured JWKS URI, if any (presumably for signature keys — confirm).
    jwks_uri: Option<String>,
    /// Configured username attribute; defaults to `"sub"`.
    username_attribute: String,
}
958
959impl OAuth2Authenticator {
960 pub fn new() -> Self {
962 Self {
963 issuer: None,
964 jwks_uri: None,
965 username_attribute: "sub".to_string(),
966 }
967 }
968
969 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
971 self.issuer = Some(issuer.into());
972 self
973 }
974
975 pub fn jwks_uri(mut self, uri: impl Into<String>) -> Self {
977 self.jwks_uri = Some(uri.into());
978 self
979 }
980
981 pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
983 self.username_attribute = attr.into();
984 self
985 }
986
987 fn extract_token(&self, req: &ServiceRequest) -> Option<String> {
989 let auth_header = req.headers().get(AUTHORIZATION)?;
990 let auth_str = auth_header.to_str().ok()?;
991
992 auth_str
993 .strip_prefix("Bearer ")
994 .map(|token| token.to_string())
995 }
996}
997
998impl Default for OAuth2Authenticator {
999 fn default() -> Self {
1000 Self::new()
1001 }
1002}
1003
1004impl Authenticator for OAuth2Authenticator {
1005 fn get_user(&self, req: &ServiceRequest) -> Option<User> {
1006 let _token = self.extract_token(req)?;
1009
1010 None
1013 }
1014}
1015
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a GitHub config; GitHub needs no network access to construct a client.
    fn github_config() -> OAuth2Config {
        OAuth2Config::new("client-id", "secret", "http://localhost/callback")
            .provider(OAuth2Provider::GitHub)
    }

    #[test]
    fn test_oauth2_config_builder() {
        let config = OAuth2Config::new("client-id", "secret", "http://localhost/callback")
            .provider(OAuth2Provider::Google)
            .add_scope("custom_scope");

        assert_eq!(config.client_id, "client-id");
        assert_eq!(config.provider, OAuth2Provider::Google);
        assert!(config.scopes.contains(&"openid".to_string()));
        assert!(config.scopes.contains(&"email".to_string()));
        assert!(config.scopes.contains(&"custom_scope".to_string()));
        assert!(config.use_pkce);
    }

    #[test]
    fn test_oauth2_provider_endpoints() {
        assert!(OAuth2Provider::Google.auth_url().is_some());
        assert!(OAuth2Provider::Google.token_url().is_some());
        assert!(OAuth2Provider::Google.userinfo_url().is_some());
        assert!(OAuth2Provider::Google.supports_oidc());

        assert!(OAuth2Provider::GitHub.auth_url().is_some());
        assert!(!OAuth2Provider::GitHub.supports_oidc());
    }

    #[test]
    fn test_oauth2_user() {
        let mut user = OAuth2User::new("user123", "google");
        user.email = Some("user@example.com".to_string());
        user.name = Some("Test User".to_string());

        assert_eq!(user.username(), "user@example.com");

        let auth_user = user.to_user();
        assert_eq!(auth_user.get_username(), "user@example.com");
        assert!(auth_user.get_roles().contains(&"USER".to_string()));
    }

    #[test]
    fn test_provider_default_scopes() {
        let google_scopes = OAuth2Provider::Google.default_scopes();
        assert!(google_scopes.contains(&"openid"));
        assert!(google_scopes.contains(&"email"));

        let github_scopes = OAuth2Provider::GitHub.default_scopes();
        assert!(github_scopes.contains(&"read:user"));
    }

    #[test]
    fn test_oauth2_client_basic() {
        let client = OAuth2Client::new_basic(github_config()).unwrap();
        assert!(!client.has_oidc());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            OAuth2Error::Configuration("bad uri".to_string()).to_string(),
            "Configuration error: bad uri"
        );
        assert_eq!(
            OAuth2Error::InvalidState("mismatch".to_string()).to_string(),
            "Invalid state: mismatch"
        );
    }

    #[test]
    fn test_provider_defaults_for_github() {
        let config = github_config();
        assert_eq!(config.registration_id, "github");
        assert!(!config.use_pkce);
        assert_eq!(
            config.scopes,
            vec!["read:user".to_string(), "user:email".to_string()]
        );
        assert_eq!(
            config.authorization_uri.as_deref(),
            OAuth2Provider::GitHub.auth_url()
        );
        assert_eq!(config.token_uri.as_deref(), OAuth2Provider::GitHub.token_url());
    }

    #[test]
    fn test_explicit_registration_id_is_kept() {
        let config = OAuth2Config::new("client-id", "secret", "http://localhost/callback")
            .registration_id("corp-google")
            .provider(OAuth2Provider::Google);
        assert_eq!(config.registration_id, "corp-google");
    }

    #[test]
    fn test_new_basic_requires_endpoints() {
        let config = OAuth2Config::new("client-id", "secret", "http://localhost/callback");
        assert!(matches!(
            OAuth2Client::new_basic(config),
            Err(OAuth2Error::Configuration(_))
        ));
    }

    #[test]
    fn test_extract_user_id() {
        let client = OAuth2Client::new_basic(github_config()).unwrap();

        let mut attrs = HashMap::new();
        attrs.insert("login".to_string(), serde_json::json!("octocat"));
        attrs.insert("id".to_string(), serde_json::json!(583231));
        assert_eq!(client.extract_user_id(&attrs).unwrap(), "583231");

        attrs.remove("id");
        assert_eq!(client.extract_user_id(&attrs).unwrap(), "octocat");

        assert!(client.extract_user_id(&HashMap::new()).is_err());
    }

    #[test]
    fn test_basic_authorization_url() {
        let client = OAuth2Client::new_basic(github_config()).unwrap();
        let (url, state, pkce_verifier, nonce) = client.authorization_url();

        assert!(pkce_verifier.is_none());
        assert!(nonce.is_none());
        assert_eq!(url.host_str(), Some("github.com"));

        let query: HashMap<String, String> = url.query_pairs().into_owned().collect();
        assert_eq!(query.get("client_id").map(String::as_str), Some("client-id"));
        assert_eq!(
            query.get("scope").map(String::as_str),
            Some("read:user user:email")
        );
        assert_eq!(
            query.get("state").map(String::as_str),
            Some(state.secret().as_str())
        );
    }
}