1use std::collections::HashMap;
47use std::fmt;
48use std::sync::Arc;
49
50use actix_web::dev::ServiceRequest;
51use actix_web::http::header::AUTHORIZATION;
52use oauth2::basic::BasicClient;
53use oauth2::reqwest::async_http_client;
54use oauth2::{
55 AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
56 PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl,
57};
58use openidconnect::core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata};
59use openidconnect::{
60 ClientId as OidcClientId, ClientSecret as OidcClientSecret, IssuerUrl, Nonce,
61 RedirectUrl as OidcRedirectUrl, TokenResponse as OidcTokenResponse,
62};
63use serde::{Deserialize, Serialize};
64use url::Url;
65
66use super::config::Authenticator;
67use super::user::User;
68
/// Failures reported by the OAuth2 / OIDC client in this module.
///
/// Every variant carries a free-form detail message that is appended to a
/// fixed prefix by the [`fmt::Display`] implementation.
#[derive(Debug, Clone)]
pub enum OAuth2Error {
    /// A required setting is missing or a configured URL could not be parsed.
    Configuration(String),
    /// Fetching the OIDC provider metadata failed.
    Discovery(String),
    /// The authorization-code-for-token request failed.
    TokenExchange(String),
    /// The ID token was absent or its claims were rejected.
    TokenValidation(String),
    /// The userinfo endpoint was unavailable or returned unusable data.
    UserInfo(String),
    /// Signals a rejected `state` value (not constructed in this file).
    InvalidState(String),
    /// Signals a rejected nonce value.
    InvalidNonce(String),
}

impl fmt::Display for OAuth2Error {
    /// Writes the variant-specific prefix followed by the detail message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Configuration(detail) => write!(f, "Configuration error: {}", detail),
            Self::Discovery(detail) => write!(f, "Discovery error: {}", detail),
            Self::TokenExchange(detail) => write!(f, "Token exchange error: {}", detail),
            Self::TokenValidation(detail) => write!(f, "Token validation error: {}", detail),
            Self::UserInfo(detail) => write!(f, "User info error: {}", detail),
            Self::InvalidState(detail) => write!(f, "Invalid state: {}", detail),
            Self::InvalidNonce(detail) => write!(f, "Invalid nonce: {}", detail),
        }
    }
}

// No underlying source error is tracked; the default trait methods suffice.
impl std::error::Error for OAuth2Error {}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
106pub enum OAuth2Provider {
107 Google,
109 GitHub,
111 Microsoft,
113 Facebook,
115 Apple,
117 Okta,
119 Auth0,
121 Keycloak,
123 Custom,
125}
126
127impl OAuth2Provider {
128 pub fn discovery_url(&self) -> Option<&'static str> {
130 match self {
131 OAuth2Provider::Google => {
132 Some("https://accounts.google.com/.well-known/openid-configuration")
133 }
134 OAuth2Provider::Microsoft => Some(
135 "https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration",
136 ),
137 OAuth2Provider::Apple => {
138 Some("https://appleid.apple.com/.well-known/openid-configuration")
139 }
140 _ => None,
141 }
142 }
143
144 pub fn auth_url(&self) -> Option<&'static str> {
146 match self {
147 OAuth2Provider::Google => Some("https://accounts.google.com/o/oauth2/v2/auth"),
148 OAuth2Provider::GitHub => Some("https://github.com/login/oauth/authorize"),
149 OAuth2Provider::Microsoft => {
150 Some("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
151 }
152 OAuth2Provider::Facebook => Some("https://www.facebook.com/v18.0/dialog/oauth"),
153 OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/authorize"),
154 _ => None,
155 }
156 }
157
158 pub fn token_url(&self) -> Option<&'static str> {
160 match self {
161 OAuth2Provider::Google => Some("https://oauth2.googleapis.com/token"),
162 OAuth2Provider::GitHub => Some("https://github.com/login/oauth/access_token"),
163 OAuth2Provider::Microsoft => {
164 Some("https://login.microsoftonline.com/common/oauth2/v2.0/token")
165 }
166 OAuth2Provider::Facebook => Some("https://graph.facebook.com/v18.0/oauth/access_token"),
167 OAuth2Provider::Apple => Some("https://appleid.apple.com/auth/token"),
168 _ => None,
169 }
170 }
171
172 pub fn userinfo_url(&self) -> Option<&'static str> {
174 match self {
175 OAuth2Provider::Google => Some("https://openidconnect.googleapis.com/v1/userinfo"),
176 OAuth2Provider::GitHub => Some("https://api.github.com/user"),
177 OAuth2Provider::Microsoft => Some("https://graph.microsoft.com/oidc/userinfo"),
178 OAuth2Provider::Facebook => Some("https://graph.facebook.com/me?fields=id,name,email"),
179 _ => None,
180 }
181 }
182
183 pub fn default_scopes(&self) -> Vec<&'static str> {
185 match self {
186 OAuth2Provider::Google => vec!["openid", "email", "profile"],
187 OAuth2Provider::GitHub => vec!["read:user", "user:email"],
188 OAuth2Provider::Microsoft => vec!["openid", "email", "profile"],
189 OAuth2Provider::Facebook => vec!["email", "public_profile"],
190 OAuth2Provider::Apple => vec!["openid", "email", "name"],
191 OAuth2Provider::Okta => vec!["openid", "email", "profile"],
192 OAuth2Provider::Auth0 => vec!["openid", "email", "profile"],
193 OAuth2Provider::Keycloak => vec!["openid", "email", "profile"],
194 OAuth2Provider::Custom => vec!["openid"],
195 }
196 }
197
198 pub fn supports_oidc(&self) -> bool {
200 matches!(
201 self,
202 OAuth2Provider::Google
203 | OAuth2Provider::Microsoft
204 | OAuth2Provider::Apple
205 | OAuth2Provider::Okta
206 | OAuth2Provider::Auth0
207 | OAuth2Provider::Keycloak
208 )
209 }
210}
211
212#[derive(Debug, Clone)]
228pub struct OAuth2Config {
229 pub registration_id: String,
231 pub client_id: String,
233 pub client_secret: String,
235 pub redirect_uri: String,
237 pub provider: OAuth2Provider,
239 pub authorization_uri: Option<String>,
241 pub token_uri: Option<String>,
243 pub userinfo_uri: Option<String>,
245 pub issuer_uri: Option<String>,
247 pub jwk_set_uri: Option<String>,
249 pub scopes: Vec<String>,
251 pub use_pkce: bool,
253 pub authorization_params: HashMap<String, String>,
255 pub username_attribute: String,
257}
258
259impl OAuth2Config {
260 pub fn new(
268 client_id: impl Into<String>,
269 client_secret: impl Into<String>,
270 redirect_uri: impl Into<String>,
271 ) -> Self {
272 Self {
273 registration_id: String::new(),
274 client_id: client_id.into(),
275 client_secret: client_secret.into(),
276 redirect_uri: redirect_uri.into(),
277 provider: OAuth2Provider::Custom,
278 authorization_uri: None,
279 token_uri: None,
280 userinfo_uri: None,
281 issuer_uri: None,
282 jwk_set_uri: None,
283 scopes: vec!["openid".to_string()],
284 use_pkce: true,
285 authorization_params: HashMap::new(),
286 username_attribute: "sub".to_string(),
287 }
288 }
289
290 pub fn registration_id(mut self, id: impl Into<String>) -> Self {
292 self.registration_id = id.into();
293 self
294 }
295
296 pub fn provider(mut self, provider: OAuth2Provider) -> Self {
300 self.provider = provider;
301 if self.registration_id.is_empty() {
302 self.registration_id = format!("{:?}", provider).to_lowercase();
303 }
304
305 if let Some(auth_url) = provider.auth_url() {
307 self.authorization_uri = Some(auth_url.to_string());
308 }
309 if let Some(token_url) = provider.token_url() {
310 self.token_uri = Some(token_url.to_string());
311 }
312 if let Some(userinfo_url) = provider.userinfo_url() {
313 self.userinfo_uri = Some(userinfo_url.to_string());
314 }
315
316 if self.scopes.len() == 1 && self.scopes[0] == "openid" {
318 self.scopes = provider
319 .default_scopes()
320 .into_iter()
321 .map(String::from)
322 .collect();
323 }
324
325 if matches!(provider, OAuth2Provider::GitHub | OAuth2Provider::Facebook) {
327 self.use_pkce = false;
328 }
329
330 self
331 }
332
333 pub fn authorization_uri(mut self, uri: impl Into<String>) -> Self {
335 self.authorization_uri = Some(uri.into());
336 self
337 }
338
339 pub fn token_uri(mut self, uri: impl Into<String>) -> Self {
341 self.token_uri = Some(uri.into());
342 self
343 }
344
345 pub fn userinfo_uri(mut self, uri: impl Into<String>) -> Self {
347 self.userinfo_uri = Some(uri.into());
348 self
349 }
350
351 pub fn issuer_uri(mut self, uri: impl Into<String>) -> Self {
353 self.issuer_uri = Some(uri.into());
354 self
355 }
356
357 pub fn jwk_set_uri(mut self, uri: impl Into<String>) -> Self {
359 self.jwk_set_uri = Some(uri.into());
360 self
361 }
362
363 pub fn scopes(mut self, scopes: Vec<impl Into<String>>) -> Self {
365 self.scopes = scopes.into_iter().map(|s| s.into()).collect();
366 self
367 }
368
369 pub fn add_scope(mut self, scope: impl Into<String>) -> Self {
371 self.scopes.push(scope.into());
372 self
373 }
374
375 pub fn use_pkce(mut self, use_pkce: bool) -> Self {
377 self.use_pkce = use_pkce;
378 self
379 }
380
381 pub fn authorization_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
383 self.authorization_params.insert(key.into(), value.into());
384 self
385 }
386
387 pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
389 self.username_attribute = attr.into();
390 self
391 }
392}
393
394#[derive(Debug, Clone, Serialize, Deserialize)]
396pub struct OAuth2User {
397 pub sub: String,
399 pub name: Option<String>,
401 pub given_name: Option<String>,
403 pub family_name: Option<String>,
405 pub email: Option<String>,
407 pub email_verified: Option<bool>,
409 pub picture: Option<String>,
411 pub locale: Option<String>,
413 #[serde(flatten)]
415 pub attributes: HashMap<String, serde_json::Value>,
416 #[serde(skip)]
418 pub access_token: Option<String>,
419 #[serde(skip)]
421 pub refresh_token: Option<String>,
422 pub expires_at: Option<i64>,
424 pub provider: String,
426}
427
428impl OAuth2User {
429 pub fn new(sub: impl Into<String>, provider: impl Into<String>) -> Self {
431 Self {
432 sub: sub.into(),
433 name: None,
434 given_name: None,
435 family_name: None,
436 email: None,
437 email_verified: None,
438 picture: None,
439 locale: None,
440 attributes: HashMap::new(),
441 access_token: None,
442 refresh_token: None,
443 expires_at: None,
444 provider: provider.into(),
445 }
446 }
447
448 pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
450 self.attributes.get(key)
451 }
452
453 pub fn username(&self) -> &str {
455 self.email.as_deref().unwrap_or(&self.sub)
456 }
457
458 pub fn to_user(&self) -> User {
460 User::new(self.username().to_string(), String::new())
461 .roles(&["USER".to_string()])
462 .authorities(&[format!("OAUTH2_USER_{}", self.provider.to_uppercase())])
463 }
464}
465
466#[derive(Debug, Clone, Serialize, Deserialize)]
468pub struct OidcUser {
469 #[serde(flatten)]
471 pub oauth2_user: OAuth2User,
472 pub id_token_claims: Option<IdTokenClaims>,
474 #[serde(skip)]
476 pub id_token: Option<String>,
477}
478
479#[derive(Debug, Clone, Serialize, Deserialize)]
481pub struct IdTokenClaims {
482 pub iss: String,
484 pub sub: String,
486 pub aud: Vec<String>,
488 pub exp: i64,
490 pub iat: i64,
492 pub auth_time: Option<i64>,
494 pub nonce: Option<String>,
496 pub at_hash: Option<String>,
498}
499
500impl OidcUser {
501 pub fn to_user(&self) -> User {
503 self.oauth2_user.to_user()
504 }
505}
506
507#[derive(Debug, Clone, Serialize, Deserialize)]
509pub struct AuthorizationRequestState {
510 pub state: String,
512 pub pkce_verifier: Option<String>,
514 pub nonce: Option<String>,
516 pub redirect_uri: Option<String>,
518 pub registration_id: String,
520 pub created_at: i64,
522}
523
524#[derive(Clone)]
528pub struct OAuth2Client {
529 config: OAuth2Config,
530 oauth2_client: BasicClient,
531 oidc_client: Option<Arc<CoreClient>>,
532}
533
534impl OAuth2Client {
535 pub async fn new(config: OAuth2Config) -> Result<Self, OAuth2Error> {
539 let auth_url = config
541 .authorization_uri
542 .as_ref()
543 .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
544
545 let token_url = config
546 .token_uri
547 .as_ref()
548 .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
549
550 let oauth2_client = BasicClient::new(
551 ClientId::new(config.client_id.clone()),
552 Some(ClientSecret::new(config.client_secret.clone())),
553 AuthUrl::new(auth_url.clone())
554 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
555 Some(
556 TokenUrl::new(token_url.clone())
557 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
558 ),
559 )
560 .set_redirect_uri(
561 RedirectUrl::new(config.redirect_uri.clone())
562 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
563 );
564
565 let oidc_client = if config.provider.supports_oidc() {
567 if let Some(issuer_uri) = &config.issuer_uri {
568 match Self::create_oidc_client(&config, issuer_uri).await {
569 Ok(client) => Some(Arc::new(client)),
570 Err(e) => {
571 eprintln!("Warning: OIDC discovery failed: {}", e);
573 None
574 }
575 }
576 } else if let Some(discovery_url) = config.provider.discovery_url() {
577 let issuer = discovery_url.trim_end_matches("/.well-known/openid-configuration");
579 match Self::create_oidc_client(&config, issuer).await {
580 Ok(client) => Some(Arc::new(client)),
581 Err(e) => {
582 eprintln!("Warning: OIDC discovery failed: {}", e);
583 None
584 }
585 }
586 } else {
587 None
588 }
589 } else {
590 None
591 };
592
593 Ok(Self {
594 config,
595 oauth2_client,
596 oidc_client,
597 })
598 }
599
600 pub fn new_basic(config: OAuth2Config) -> Result<Self, OAuth2Error> {
604 let auth_url = config
605 .authorization_uri
606 .as_ref()
607 .ok_or_else(|| OAuth2Error::Configuration("Missing authorization URI".to_string()))?;
608
609 let token_url = config
610 .token_uri
611 .as_ref()
612 .ok_or_else(|| OAuth2Error::Configuration("Missing token URI".to_string()))?;
613
614 let oauth2_client = BasicClient::new(
615 ClientId::new(config.client_id.clone()),
616 Some(ClientSecret::new(config.client_secret.clone())),
617 AuthUrl::new(auth_url.clone())
618 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
619 Some(
620 TokenUrl::new(token_url.clone())
621 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
622 ),
623 )
624 .set_redirect_uri(
625 RedirectUrl::new(config.redirect_uri.clone())
626 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
627 );
628
629 Ok(Self {
630 config,
631 oauth2_client,
632 oidc_client: None,
633 })
634 }
635
636 async fn create_oidc_client(
638 config: &OAuth2Config,
639 issuer_uri: &str,
640 ) -> Result<CoreClient, OAuth2Error> {
641 let issuer_url = IssuerUrl::new(issuer_uri.to_string())
642 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?;
643
644 let provider_metadata = CoreProviderMetadata::discover_async(
646 issuer_url,
647 openidconnect::reqwest::async_http_client,
648 )
649 .await
650 .map_err(|e| OAuth2Error::Discovery(format!("{:?}", e)))?;
651
652 let client = CoreClient::from_provider_metadata(
654 provider_metadata,
655 OidcClientId::new(config.client_id.clone()),
656 Some(OidcClientSecret::new(config.client_secret.clone())),
657 )
658 .set_redirect_uri(
659 OidcRedirectUrl::new(config.redirect_uri.clone())
660 .map_err(|e| OAuth2Error::Configuration(e.to_string()))?,
661 );
662
663 Ok(client)
664 }
665
666 pub fn authorization_url(&self) -> (Url, CsrfToken, Option<PkceCodeVerifier>, Option<Nonce>) {
670 if let Some(oidc_client) = &self.oidc_client {
671 let mut auth_request = oidc_client.authorize_url(
673 CoreAuthenticationFlow::AuthorizationCode,
674 CsrfToken::new_random,
675 Nonce::new_random,
676 );
677
678 for scope in &self.config.scopes {
680 auth_request = auth_request.add_scope(openidconnect::Scope::new(scope.clone()));
681 }
682
683 let pkce_verifier = if self.config.use_pkce {
685 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
686 auth_request = auth_request.set_pkce_challenge(pkce_challenge);
687 Some(pkce_verifier)
688 } else {
689 None
690 };
691
692 let (url, state, nonce) = auth_request.url();
693 (url, state, pkce_verifier, Some(nonce))
694 } else {
695 let mut auth_request = self.oauth2_client.authorize_url(CsrfToken::new_random);
697
698 for scope in &self.config.scopes {
700 auth_request = auth_request.add_scope(Scope::new(scope.clone()));
701 }
702
703 let pkce_verifier = if self.config.use_pkce {
705 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
706 auth_request = auth_request.set_pkce_challenge(pkce_challenge);
707 Some(pkce_verifier)
708 } else {
709 None
710 };
711
712 let (url, state) = auth_request.url();
713 (url, state, pkce_verifier, None)
714 }
715 }
716
717 pub async fn exchange_code(
719 &self,
720 code: &str,
721 pkce_verifier: Option<PkceCodeVerifier>,
722 nonce: Option<&Nonce>,
723 ) -> Result<(OAuth2User, Option<OidcUser>), OAuth2Error> {
724 if let Some(oidc_client) = &self.oidc_client {
725 let mut token_request =
727 oidc_client.exchange_code(AuthorizationCode::new(code.to_string()));
728
729 if let Some(verifier) = pkce_verifier {
730 token_request = token_request.set_pkce_verifier(verifier);
731 }
732
733 let token_response = token_request
734 .request_async(openidconnect::reqwest::async_http_client)
735 .await
736 .map_err(|e| OAuth2Error::TokenExchange(format!("{:?}", e)))?;
737
738 let id_token = token_response
740 .id_token()
741 .ok_or_else(|| OAuth2Error::TokenValidation("Missing ID token".to_string()))?;
742
743 let id_token_verifier = oidc_client.id_token_verifier();
744 let nonce_ref = nonce.cloned().unwrap_or_else(|| Nonce::new(String::new()));
745 let claims = id_token.claims(&id_token_verifier, &nonce_ref).map_err(
746 |e: openidconnect::ClaimsVerificationError| {
747 OAuth2Error::TokenValidation(e.to_string())
748 },
749 )?;
750
751 let subject = claims.subject().as_str().to_string();
753 let issuer = claims.issuer().as_str().to_string();
754 let exp = claims.expiration().timestamp();
755 let iat = claims.issue_time().timestamp();
756
757 let mut oauth2_user = OAuth2User::new(&subject, &self.config.registration_id);
759 oauth2_user.access_token = Some(token_response.access_token().secret().clone());
760 oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
761 oauth2_user.email_verified = claims.email_verified();
762
763 let oidc_user = OidcUser {
765 oauth2_user: oauth2_user.clone(),
766 id_token_claims: Some(IdTokenClaims {
767 iss: issuer,
768 sub: subject,
769 aud: vec![self.config.client_id.clone()],
770 exp,
771 iat,
772 auth_time: None,
773 nonce: None,
774 at_hash: None,
775 }),
776 id_token: Some(id_token.to_string()),
777 };
778
779 Ok((oauth2_user, Some(oidc_user)))
780 } else {
781 let mut token_request = self
783 .oauth2_client
784 .exchange_code(AuthorizationCode::new(code.to_string()));
785
786 if let Some(verifier) = pkce_verifier {
787 token_request = token_request.set_pkce_verifier(verifier);
788 }
789
790 let token_response = token_request
791 .request_async(async_http_client)
792 .await
793 .map_err(|e| OAuth2Error::TokenExchange(e.to_string()))?;
794
795 let mut oauth2_user = self
797 .fetch_user_info(token_response.access_token().secret())
798 .await?;
799
800 oauth2_user.access_token = Some(token_response.access_token().secret().clone());
801 oauth2_user.refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
802
803 Ok((oauth2_user, None))
804 }
805 }
806
807 async fn fetch_user_info(&self, access_token: &str) -> Result<OAuth2User, OAuth2Error> {
809 let userinfo_url = self
810 .config
811 .userinfo_uri
812 .as_ref()
813 .ok_or_else(|| OAuth2Error::UserInfo("Missing userinfo URI".to_string()))?;
814
815 let http_client = reqwest::Client::new();
816 let response = http_client
817 .get(userinfo_url)
818 .bearer_auth(access_token)
819 .header("Accept", "application/json")
820 .send()
821 .await
822 .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
823
824 if !response.status().is_success() {
825 return Err(OAuth2Error::UserInfo(format!(
826 "HTTP {}: {}",
827 response.status(),
828 response.text().await.unwrap_or_default()
829 )));
830 }
831
832 let attributes: HashMap<String, serde_json::Value> = response
833 .json()
834 .await
835 .map_err(|e| OAuth2Error::UserInfo(e.to_string()))?;
836
837 let sub = self.extract_user_id(&attributes)?;
839 let mut user = OAuth2User::new(sub, &self.config.registration_id);
840 user.access_token = Some(access_token.to_string());
841 user.attributes = attributes.clone();
842
843 user.name = attributes
845 .get("name")
846 .and_then(|v| v.as_str())
847 .map(String::from);
848 user.email = attributes
849 .get("email")
850 .and_then(|v| v.as_str())
851 .map(String::from);
852 user.picture = attributes
853 .get("picture")
854 .or_else(|| attributes.get("avatar_url"))
855 .and_then(|v| v.as_str())
856 .map(String::from);
857
858 Ok(user)
859 }
860
861 fn extract_user_id(
863 &self,
864 attributes: &HashMap<String, serde_json::Value>,
865 ) -> Result<String, OAuth2Error> {
866 let id_fields = ["sub", "id", "user_id", "login"];
868
869 for field in &id_fields {
870 if let Some(value) = attributes.get(*field) {
871 if let Some(s) = value.as_str() {
872 return Ok(s.to_string());
873 }
874 if let Some(n) = value.as_i64() {
875 return Ok(n.to_string());
876 }
877 }
878 }
879
880 Err(OAuth2Error::UserInfo(
881 "Could not extract user ID".to_string(),
882 ))
883 }
884
885 pub fn config(&self) -> &OAuth2Config {
887 &self.config
888 }
889
890 pub fn has_oidc(&self) -> bool {
892 self.oidc_client.is_some()
893 }
894}
895
896#[derive(Clone, Default)]
900pub struct OAuth2ClientRepository {
901 clients: HashMap<String, OAuth2Client>,
902}
903
904impl OAuth2ClientRepository {
905 pub fn new() -> Self {
907 Self {
908 clients: HashMap::new(),
909 }
910 }
911
912 pub fn add_client(&mut self, client: OAuth2Client) {
914 self.clients
915 .insert(client.config.registration_id.clone(), client);
916 }
917
918 pub fn get_client(&self, registration_id: &str) -> Option<&OAuth2Client> {
920 self.clients.get(registration_id)
921 }
922
923 pub fn registration_ids(&self) -> Vec<&String> {
925 self.clients.keys().collect()
926 }
927
928 pub async fn from_configs(configs: Vec<OAuth2Config>) -> Result<Self, OAuth2Error> {
930 let mut repo = Self::new();
931 for config in configs {
932 let client = OAuth2Client::new(config).await?;
933 repo.add_client(client);
934 }
935 Ok(repo)
936 }
937}
938
939#[derive(Clone)]
944pub struct OAuth2Authenticator {
945 issuer: Option<String>,
947 jwks_uri: Option<String>,
949 username_attribute: String,
951}
952
953impl OAuth2Authenticator {
954 pub fn new() -> Self {
956 Self {
957 issuer: None,
958 jwks_uri: None,
959 username_attribute: "sub".to_string(),
960 }
961 }
962
963 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
965 self.issuer = Some(issuer.into());
966 self
967 }
968
969 pub fn jwks_uri(mut self, uri: impl Into<String>) -> Self {
971 self.jwks_uri = Some(uri.into());
972 self
973 }
974
975 pub fn username_attribute(mut self, attr: impl Into<String>) -> Self {
977 self.username_attribute = attr.into();
978 self
979 }
980
981 fn extract_token(&self, req: &ServiceRequest) -> Option<String> {
983 let auth_header = req.headers().get(AUTHORIZATION)?;
984 let auth_str = auth_header.to_str().ok()?;
985
986 auth_str
987 .strip_prefix("Bearer ")
988 .map(|token| token.to_string())
989 }
990}
991
992impl Default for OAuth2Authenticator {
993 fn default() -> Self {
994 Self::new()
995 }
996}
997
998impl Authenticator for OAuth2Authenticator {
999 fn get_user(&self, req: &ServiceRequest) -> Option<User> {
1000 let _token = self.extract_token(req)?;
1003
1004 None
1007 }
1008}
1009
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration with fixed dummy credentials shared by the tests below.
    fn sample_config() -> OAuth2Config {
        OAuth2Config::new("client-id", "secret", "http://localhost/callback")
    }

    /// Provider presets and appended scopes end up in the built configuration.
    #[test]
    fn test_oauth2_config_builder() {
        let config = sample_config()
            .provider(OAuth2Provider::Google)
            .add_scope("custom_scope");

        assert_eq!(config.client_id, "client-id");
        assert_eq!(config.provider, OAuth2Provider::Google);
        for expected in &["openid", "email", "custom_scope"] {
            assert!(config.scopes.iter().any(|scope| scope.as_str() == *expected));
        }
        assert!(config.use_pkce);
    }

    /// Google has all preset endpoints and OIDC; GitHub has endpoints but no OIDC.
    #[test]
    fn test_oauth2_provider_endpoints() {
        let google = OAuth2Provider::Google;
        assert!(google.auth_url().is_some());
        assert!(google.token_url().is_some());
        assert!(google.userinfo_url().is_some());
        assert!(google.supports_oidc());

        let github = OAuth2Provider::GitHub;
        assert!(github.auth_url().is_some());
        assert!(!github.supports_oidc());
    }

    /// The email is preferred as username and carried over into the app user.
    #[test]
    fn test_oauth2_user() {
        let mut user = OAuth2User::new("user123", "google");
        user.email = Some("user@example.com".to_string());
        user.name = Some("Test User".to_string());

        assert_eq!(user.username(), "user@example.com");

        let auth_user = user.to_user();
        assert_eq!(auth_user.get_username(), "user@example.com");
        assert!(auth_user.get_roles().contains(&"USER".to_string()));
    }

    /// Spot-checks the default scope tables.
    #[test]
    fn test_provider_default_scopes() {
        let google_scopes = OAuth2Provider::Google.default_scopes();
        assert!(google_scopes.contains(&"openid"));
        assert!(google_scopes.contains(&"email"));

        let github_scopes = OAuth2Provider::GitHub.default_scopes();
        assert!(github_scopes.contains(&"read:user"));
    }

    /// A client built without discovery never reports OIDC support.
    #[test]
    fn test_oauth2_client_basic() {
        let config = sample_config().provider(OAuth2Provider::GitHub);

        let client = OAuth2Client::new_basic(config).unwrap();
        assert!(!client.has_oidc());
    }
}