1impl Default for ProviderProfile {
3 fn default() -> Self {
4 Self::new()
5 }
6}
7use crate::errors::{AuthError, Result};
8use crate::tokens::AuthToken;
9use base64::Engine;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::fmt;
15use url::Url;
16
/// An OAuth 2.0 identity provider.
///
/// Built-in variants map to preset endpoint configurations returned by
/// `OAuthProvider::config`; `Custom` carries its own configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OAuthProvider {
    /// GitHub (`github.com`).
    GitHub,

    /// Google accounts.
    Google,

    /// Microsoft identity platform (`common` tenant endpoints).
    Microsoft,

    /// Discord.
    Discord,

    /// Twitter / X (OAuth 2.0, API v2).
    Twitter,

    /// Facebook (Graph API).
    Facebook,

    /// LinkedIn.
    LinkedIn,

    /// GitLab.com.
    GitLab,

    /// A provider not covered by the built-in presets.
    Custom {
        /// Identifier returned by `OAuthProvider::name`.
        name: String,
        /// Endpoint and capability configuration (presumably boxed to keep the
        /// enum itself small).
        config: Box<OAuthProviderConfig>,
    },
}
50
/// Endpoints and capability flags describing one OAuth 2.0 provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthProviderConfig {
    /// Authorization endpoint; the base URL that
    /// `OAuthProvider::build_authorization_url` appends query parameters to.
    pub authorization_url: String,

    /// Token endpoint used for code exchange, refresh and device-code polling.
    pub token_url: String,

    /// Device authorization endpoint (RFC 8628); `None` when not configured.
    pub device_authorization_url: Option<String>,

    /// User-info endpoint; `OAuthProvider::get_user_info` errors when `None`.
    pub userinfo_url: Option<String>,

    /// Token revocation endpoint; `OAuthProvider::revoke_token` errors when `None`.
    pub revocation_url: Option<String>,

    /// Scopes requested when the caller does not supply any.
    pub default_scopes: crate::types::Scopes,

    /// Whether a PKCE `code_challenge` is added to authorization URLs.
    pub supports_pkce: bool,

    /// Whether `OAuthProvider::refresh_token` is allowed for this provider.
    pub supports_refresh: bool,

    /// Whether the device authorization flow is allowed for this provider.
    pub supports_device_flow: bool,

    /// Extra query parameters appended to every authorization URL.
    pub additional_params: crate::types::AdditionalParams,
}
84
85impl OAuthProviderConfig {
86 pub fn builder(
105 authorization_url: impl Into<String>,
106 token_url: impl Into<String>,
107 ) -> OAuthProviderConfigBuilder {
108 OAuthProviderConfigBuilder {
109 inner: OAuthProviderConfig {
110 authorization_url: authorization_url.into(),
111 token_url: token_url.into(),
112 device_authorization_url: None,
113 userinfo_url: None,
114 revocation_url: None,
115 default_scopes: crate::types::Scopes::empty(),
116 supports_pkce: false,
117 supports_refresh: false,
118 supports_device_flow: false,
119 additional_params: crate::types::AdditionalParams::new(),
120 },
121 }
122 }
123}
124
/// Fluent builder for [`OAuthProviderConfig`]; obtain one from
/// `OAuthProviderConfig::builder` and finish with `build`.
#[derive(Debug, Clone)]
pub struct OAuthProviderConfigBuilder {
    /// Configuration being assembled; returned as-is by `build`.
    inner: OAuthProviderConfig,
}
130
131impl OAuthProviderConfigBuilder {
132 pub fn device_authorization_url(mut self, url: impl Into<String>) -> Self {
134 self.inner.device_authorization_url = Some(url.into());
135 self
136 }
137
138 pub fn userinfo_url(mut self, url: impl Into<String>) -> Self {
140 self.inner.userinfo_url = Some(url.into());
141 self
142 }
143
144 pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
146 self.inner.revocation_url = Some(url.into());
147 self
148 }
149
150 pub fn default_scope(mut self, scope: impl Into<String>) -> Self {
152 self.inner.default_scopes.push(scope.into());
153 self
154 }
155
156 pub fn supports_pkce(mut self, yes: bool) -> Self {
158 self.inner.supports_pkce = yes;
159 self
160 }
161
162 pub fn supports_refresh(mut self, yes: bool) -> Self {
164 self.inner.supports_refresh = yes;
165 self
166 }
167
168 pub fn supports_device_flow(mut self, yes: bool) -> Self {
170 self.inner.supports_device_flow = yes;
171 self
172 }
173
174 pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
176 self.inner
177 .additional_params
178 .insert(key.into(), value.into());
179 self
180 }
181
182 pub fn build(self) -> OAuthProviderConfig {
184 self.inner
185 }
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct DeviceAuthorizationResponse {
191 pub device_code: String,
193
194 pub user_code: String,
196
197 pub verification_uri: String,
199
200 pub verification_uri_complete: Option<String>,
202
203 pub interval: u64,
205
206 pub expires_in: u64,
208}
209
/// Provider-neutral user profile assembled from a provider's user API or from
/// ID-token claims. Every field is optional because providers expose
/// different subsets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderProfile {
    /// Provider-side user identifier (`sub`/`id` claim, or the API's id).
    pub id: Option<String>,

    /// Name of the provider this profile came from (e.g. `"github"`).
    pub provider: Option<String>,

    /// Login handle / username.
    pub username: Option<String>,

    /// Display name.
    pub name: Option<String>,

    /// Email address, if the provider returned one.
    pub email: Option<String>,

    /// Whether the provider reports the email as verified; `None` if not reported.
    pub email_verified: Option<bool>,

    /// Avatar / profile picture URL.
    pub picture: Option<String>,

    /// Locale or preferred language as reported by the provider.
    pub locale: Option<String>,

    /// Provider-specific extras not mapped to a field above (for example the
    /// raw ID-token claims under `"id_token_claims"`).
    pub additional_data: HashMap<String, serde_json::Value>,
}
244
245#[cfg(feature = "postgres-storage")]
246use sqlx::{Decode, Postgres, Type, postgres::PgValueRef};
247
#[cfg(feature = "postgres-storage")]
impl<'r> Decode<'r, Postgres> for ProviderProfile {
    /// Decodes a profile stored in a column readable as `serde_json::Value`.
    ///
    /// The column is read as JSON first, then deserialized into a profile;
    /// a failure at either step is returned as a boxed error.
    fn decode(value: PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let raw = <Value as Decode<Postgres>>::decode(value)?;
        let profile = serde_json::from_value::<Self>(raw)?;
        Ok(profile)
    }
}
255
#[cfg(feature = "postgres-storage")]
impl Type<Postgres> for ProviderProfile {
    /// Uses the same Postgres type as `serde_json::Value`.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        <Value as Type<Postgres>>::type_info()
    }

    /// Accepts exactly the column types that `serde_json::Value` accepts.
    fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
        <Value as Type<Postgres>>::compatible(ty)
    }
}
265
266impl ProviderProfile {
267 pub fn new() -> Self {
269 Self {
270 id: None,
271 provider: None,
272 username: None,
273 name: None,
274 email: None,
275 email_verified: None,
276 picture: None,
277 locale: None,
278 additional_data: HashMap::new(),
279 }
280 }
281
282 pub fn with_id(mut self, id: impl Into<String>) -> Self {
284 self.id = Some(id.into());
285 self
286 }
287
288 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
290 self.provider = Some(provider.into());
291 self
292 }
293
294 pub fn with_username(mut self, username: Option<impl Into<String>>) -> Self {
296 self.username = username.map(Into::into);
297 self
298 }
299
300 pub fn with_name(mut self, name: Option<impl Into<String>>) -> Self {
302 self.name = name.map(Into::into);
303 self
304 }
305
306 pub fn with_email(mut self, email: Option<impl Into<String>>) -> Self {
308 self.email = email.map(Into::into);
309 self
310 }
311
312 pub fn with_email_verified(mut self, verified: bool) -> Self {
314 self.email_verified = Some(verified);
315 self
316 }
317
318 pub fn with_picture(mut self, picture: Option<impl Into<String>>) -> Self {
320 self.picture = picture.map(Into::into);
321 self
322 }
323
324 pub fn with_locale(mut self, locale: Option<impl Into<String>>) -> Self {
326 self.locale = locale.map(Into::into);
327 self
328 }
329
330 pub fn with_additional_data(
332 mut self,
333 key: impl Into<String>,
334 value: serde_json::Value,
335 ) -> Self {
336 self.additional_data.insert(key.into(), value);
337 self
338 }
339
340 pub fn from_token_response(
342 token: &OAuthTokenResponse,
343 provider: &OAuthProvider,
344 ) -> Option<Self> {
345 if let Some(id_token_value) = token.additional_fields.get("id_token")
347 && let Some(id_token) = id_token_value.as_str()
348 && let Ok(profile) = Self::from_id_token(id_token)
349 {
350 return Some(profile.with_provider(provider.to_string()));
351 }
352 None
353 }
354
355 pub fn from_id_token(id_token: &str) -> Result<Self> {
357 let parts: Vec<&str> = id_token.split('.').collect();
359 if parts.len() != 3 {
360 return Err(AuthError::validation("Invalid JWT format"));
361 }
362
363 let payload = parts[1];
365 let padding_len = payload.len() % 4;
366 let padded_payload = if padding_len > 0 {
367 format!("{}{}", payload, "=".repeat(4 - padding_len))
368 } else {
369 payload.to_string()
370 };
371
372 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
374 .decode(&padded_payload)
375 .map_err(|e| AuthError::validation(format!("Failed to decode JWT: {}", e)))?;
376
377 let json: Value = serde_json::from_slice(&decoded)
379 .map_err(|e| AuthError::validation(format!("Failed to parse JWT payload: {}", e)))?;
380
381 let mut profile = Self::new();
383
384 if let Some(sub) = json.get("sub").and_then(|v| v.as_str()) {
386 profile = profile.with_id(sub);
387 } else if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
388 profile = profile.with_id(id);
389 } else {
390 return Err(AuthError::validation("JWT missing subject claim"));
391 }
392
393 if let Some(name) = json.get("name").and_then(|v| v.as_str()) {
395 profile = profile.with_name(Some(name));
396 }
397
398 if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
399 profile = profile.with_email(Some(email));
400 }
401
402 if let Some(verified) = json.get("email_verified").and_then(|v| v.as_bool()) {
403 profile = profile.with_email_verified(verified);
404 }
405
406 if let Some(preferred_username) = json.get("preferred_username").and_then(|v| v.as_str()) {
407 profile = profile.with_username(Some(preferred_username));
408 }
409
410 if let Some(picture) = json.get("picture").and_then(|v| v.as_str()) {
411 profile = profile.with_picture(Some(picture));
412 }
413
414 if let Some(locale) = json.get("locale").and_then(|v| v.as_str()) {
415 profile = profile.with_locale(Some(locale));
416 }
417
418 profile = profile.with_additional_data("id_token_claims", json);
420
421 Ok(profile)
422 }
423
424 pub fn to_auth_token(&self, access_token: String) -> AuthToken {
427 self.to_auth_token_with_lifetime(access_token, std::time::Duration::from_secs(3600))
428 }
429
430 pub fn to_auth_token_with_lifetime(
432 &self,
433 access_token: String,
434 lifetime: std::time::Duration,
435 ) -> AuthToken {
436 let user_id = self.id.as_deref().unwrap_or("unknown").to_string();
437 let auth_method = self.provider.as_deref().unwrap_or("oauth").to_string();
438
439 let mut token = AuthToken::new(user_id.clone(), access_token, lifetime, auth_method);
440 token.subject = self.id.clone();
441 token.issuer = self.provider.clone();
442 token.user_profile = Some(self.clone());
443 token
444 }
445
446 pub fn has_id(&self) -> bool {
448 self.id.is_some()
449 }
450
451 pub fn display_name(&self) -> Option<&str> {
453 self.name.as_deref().or(self.username.as_deref())
454 }
455}
456
/// Token-endpoint response returned by code exchange, refresh and device polling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthTokenResponse {
    /// The issued access token.
    pub access_token: String,

    /// Token type as reported by the provider (e.g. `"Bearer"`).
    pub token_type: String,

    /// Access-token lifetime in seconds, when the provider reports it.
    pub expires_in: Option<u64>,

    /// Refresh token, when one was issued.
    pub refresh_token: Option<String>,

    /// Granted scope string as reported by the provider (RFC 6749 uses
    /// space-delimited scopes, but formats vary by provider).
    pub scope: Option<String>,

    /// Every other member of the JSON response, e.g. an OIDC `id_token`
    /// (which `ProviderProfile::from_token_response` reads from here).
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
479
/// User information returned by `OAuthProvider::get_user_info`, normalized
/// across providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUserInfo {
    /// Provider-side user identifier, always rendered as a string.
    pub id: String,

    /// Login handle / username, if available.
    pub username: Option<String>,

    /// Display name, if available.
    pub name: Option<String>,

    /// Email address, if available.
    pub email: Option<String>,

    /// Whether the provider reports the email as verified.
    pub email_verified: Option<bool>,

    /// Avatar / profile picture URL.
    pub picture: Option<String>,

    /// Locale, if reported.
    pub locale: Option<String>,

    /// Raw members of the user-info response; populated only by the generic
    /// parser (left empty for GitHub and Google).
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
508
509impl OAuthProvider {
510 pub fn config(&self) -> OAuthProviderConfig {
512 match self {
513 Self::GitHub => OAuthProviderConfig {
514 authorization_url: "https://github.com/login/oauth/authorize".to_string(),
515 token_url: "https://github.com/login/oauth/access_token".to_string(),
516 device_authorization_url: Some("https://github.com/login/device/code".to_string()),
517 userinfo_url: Some("https://api.github.com/user".to_string()),
518 revocation_url: None,
519 default_scopes: vec!["user:email".to_string()].into(),
520 supports_pkce: true,
521 supports_refresh: false,
522 supports_device_flow: true,
523 additional_params: crate::types::AdditionalParams::new(),
524 },
525
526 Self::Google => OAuthProviderConfig {
527 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
528 token_url: "https://oauth2.googleapis.com/token".to_string(),
529 device_authorization_url: Some(
530 "https://oauth2.googleapis.com/device/code".to_string(),
531 ),
532 userinfo_url: Some("https://www.googleapis.com/oauth2/v2/userinfo".to_string()),
533 revocation_url: Some("https://oauth2.googleapis.com/revoke".to_string()),
534 default_scopes: vec![
535 "openid".to_string(),
536 "profile".to_string(),
537 "email".to_string(),
538 ]
539 .into(),
540 supports_pkce: true,
541 supports_refresh: true,
542 supports_device_flow: true,
543 additional_params: crate::types::AdditionalParams::new(),
544 },
545
546 Self::Microsoft => OAuthProviderConfig {
547 authorization_url: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
548 .to_string(),
549 token_url: "https://login.microsoftonline.com/common/oauth2/v2.0/token".to_string(),
550 device_authorization_url: Some(
551 "https://login.microsoftonline.com/common/oauth2/v2.0/devicecode".to_string(),
552 ),
553 userinfo_url: Some("https://graph.microsoft.com/v1.0/me".to_string()),
554 revocation_url: None,
555 default_scopes: vec![
556 "openid".to_string(),
557 "profile".to_string(),
558 "email".to_string(),
559 ]
560 .into(),
561 supports_pkce: true,
562 supports_refresh: true,
563 supports_device_flow: true,
564 additional_params: crate::types::AdditionalParams::new(),
565 },
566
567 Self::Discord => OAuthProviderConfig {
568 authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
569 token_url: "https://discord.com/api/oauth2/token".to_string(),
570 device_authorization_url: None,
571 userinfo_url: Some("https://discord.com/api/users/@me".to_string()),
572 revocation_url: Some("https://discord.com/api/oauth2/token/revoke".to_string()),
573 default_scopes: vec!["identify".to_string(), "email".to_string()].into(),
574 supports_pkce: false,
575 supports_refresh: true,
576 supports_device_flow: false,
577 additional_params: crate::types::AdditionalParams::new(),
578 },
579
580 Self::Twitter => OAuthProviderConfig {
581 authorization_url: "https://twitter.com/i/oauth2/authorize".to_string(),
582 token_url: "https://api.twitter.com/2/oauth2/token".to_string(),
583 device_authorization_url: None,
584 userinfo_url: Some("https://api.twitter.com/2/users/me".to_string()),
585 revocation_url: Some("https://api.twitter.com/2/oauth2/revoke".to_string()),
586 default_scopes: vec!["tweet.read".to_string(), "users.read".to_string()].into(),
587 supports_pkce: true,
588 supports_refresh: true,
589 supports_device_flow: false,
590 additional_params: crate::types::AdditionalParams::new(),
591 },
592
593 Self::Facebook => OAuthProviderConfig {
594 authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
595 token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
596 device_authorization_url: None,
597 userinfo_url: Some("https://graph.facebook.com/me".to_string()),
598 revocation_url: None,
599 default_scopes: vec!["email".to_string(), "public_profile".to_string()].into(),
600 supports_pkce: false,
601 supports_refresh: false,
602 supports_device_flow: false,
603 additional_params: crate::types::AdditionalParams::new(),
604 },
605
606 Self::LinkedIn => OAuthProviderConfig {
607 authorization_url: "https://www.linkedin.com/oauth/v2/authorization".to_string(),
608 token_url: "https://www.linkedin.com/oauth/v2/accessToken".to_string(),
609 device_authorization_url: None,
610 userinfo_url: Some("https://api.linkedin.com/v2/me".to_string()),
611 revocation_url: None,
612 default_scopes: vec!["r_liteprofile".to_string(), "r_emailaddress".to_string()]
613 .into(),
614 supports_pkce: false,
615 supports_refresh: true,
616 supports_device_flow: false,
617 additional_params: crate::types::AdditionalParams::new(),
618 },
619
620 Self::GitLab => OAuthProviderConfig {
621 authorization_url: "https://gitlab.com/oauth/authorize".to_string(),
622 token_url: "https://gitlab.com/oauth/token".to_string(),
623 device_authorization_url: None,
624 userinfo_url: Some("https://gitlab.com/api/v4/user".to_string()),
625 revocation_url: Some("https://gitlab.com/oauth/revoke".to_string()),
626 default_scopes: vec!["read_user".to_string()].into(),
627 supports_pkce: true,
628 supports_refresh: true,
629 supports_device_flow: false,
630 additional_params: crate::types::AdditionalParams::new(),
631 },
632
633 Self::Custom { config, .. } => *config.clone(),
634 }
635 }
636
637 pub fn name(&self) -> &str {
639 match self {
640 Self::GitHub => "github",
641 Self::Google => "google",
642 Self::Microsoft => "microsoft",
643 Self::Discord => "discord",
644 Self::Twitter => "twitter",
645 Self::Facebook => "facebook",
646 Self::LinkedIn => "linkedin",
647 Self::GitLab => "gitlab",
648 Self::Custom { name, .. } => name,
649 }
650 }
651
652 pub fn custom(name: impl Into<String>, config: OAuthProviderConfig) -> Self {
654 Self::Custom {
655 name: name.into(),
656 config: Box::new(config),
657 }
658 }
659
660 pub fn build_authorization_url(
662 &self,
663 client_id: &str,
664 redirect_uri: &str,
665 state: &str,
666 scopes: Option<&[String]>,
667 code_challenge: Option<&str>,
668 ) -> Result<String> {
669 let config = self.config();
670 let mut url = Url::parse(&config.authorization_url)
671 .map_err(|e| AuthError::config(format!("Invalid authorization URL: {e}")))?;
672
673 let scopes = scopes.unwrap_or(config.default_scopes.as_slice());
674
675 {
676 let mut query = url.query_pairs_mut();
677 query.append_pair("client_id", client_id);
678 query.append_pair("redirect_uri", redirect_uri);
679 query.append_pair("response_type", "code");
680 query.append_pair("state", state);
681
682 if !scopes.is_empty() {
683 query.append_pair("scope", &scopes.join(" "));
684 }
685
686 if config.supports_pkce
688 && let Some(challenge) = code_challenge
689 {
690 query.append_pair("code_challenge", challenge);
691 query.append_pair("code_challenge_method", "S256");
692 }
693
694 for (key, value) in &config.additional_params {
696 query.append_pair(key, value);
697 }
698 }
699
700 Ok(url.to_string())
701 }
702
703 pub async fn exchange_code(
705 &self,
706 client_id: &str,
707 client_secret: &str,
708 authorization_code: &str,
709 redirect_uri: &str,
710 code_verifier: Option<&str>,
711 ) -> Result<OAuthTokenResponse> {
712 let config = self.config();
713 let client = reqwest::Client::new();
714
715 let mut params = HashMap::new();
716 params.insert("grant_type".to_string(), "authorization_code".to_string());
717 params.insert("client_id".to_string(), client_id.to_string());
718 params.insert("client_secret".to_string(), client_secret.to_string());
719 params.insert("code".to_string(), authorization_code.to_string());
720 params.insert("redirect_uri".to_string(), redirect_uri.to_string());
721
722 if let Some(verifier) = code_verifier {
724 params.insert("code_verifier".to_string(), verifier.to_string());
725 }
726
727 let response = client.post(&config.token_url).form(¶ms).send().await?;
728
729 if !response.status().is_success() {
730 let error_text = response.text().await.unwrap_or_default();
731 return Err(AuthError::auth_method(
732 self.name(),
733 format!("Token exchange failed: {error_text}"),
734 ));
735 }
736
737 let token_response: OAuthTokenResponse = response.json().await?;
738 Ok(token_response)
739 }
740
741 pub async fn refresh_token(
743 &self,
744 client_id: &str,
745 client_secret: &str,
746 refresh_token: &str,
747 ) -> Result<OAuthTokenResponse> {
748 let config = self.config();
749
750 if !config.supports_refresh {
751 return Err(AuthError::auth_method(
752 self.name(),
753 "Provider does not support token refresh".to_string(),
754 ));
755 }
756
757 let client = reqwest::Client::new();
758
759 let mut params = HashMap::new();
760 params.insert("grant_type".to_string(), "refresh_token".to_string());
761 params.insert("client_id".to_string(), client_id.to_string());
762 params.insert("client_secret".to_string(), client_secret.to_string());
763 params.insert("refresh_token".to_string(), refresh_token.to_string());
764
765 let response = client.post(&config.token_url).form(¶ms).send().await?;
766
767 if !response.status().is_success() {
768 let error_text = response.text().await.unwrap_or_default();
769 return Err(AuthError::auth_method(
770 self.name(),
771 format!("Token refresh failed: {error_text}"),
772 ));
773 }
774
775 let token_response: OAuthTokenResponse = response.json().await?;
776 Ok(token_response)
777 }
778
779 pub async fn get_user_info(&self, access_token: &str) -> Result<OAuthUserInfo> {
781 let config = self.config();
782
783 let userinfo_url = config.userinfo_url.ok_or_else(|| {
784 AuthError::auth_method(
785 self.name(),
786 "Provider does not support user info endpoint".to_string(),
787 )
788 })?;
789
790 let client = reqwest::Client::new();
791 let response = client
792 .get(&userinfo_url)
793 .bearer_auth(access_token)
794 .send()
795 .await?;
796
797 if !response.status().is_success() {
798 let error_text = response.text().await.unwrap_or_default();
799 return Err(AuthError::auth_method(
800 self.name(),
801 format!("User info request failed: {error_text}"),
802 ));
803 }
804
805 let user_data: serde_json::Value = response.json().await?;
806
807 let user_info = self.parse_user_info(user_data)?;
809 Ok(user_info)
810 }
811
812 fn parse_user_info(&self, data: serde_json::Value) -> Result<OAuthUserInfo> {
814 let mut additional_fields = HashMap::new();
815
816 let user_info = match self {
817 Self::GitHub => {
818 let id = data["id"]
819 .as_u64()
820 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
821 .to_string();
822
823 OAuthUserInfo {
824 id,
825 username: data["login"].as_str().map(|s| s.to_string()),
826 email: data["email"].as_str().map(|s| s.to_string()),
827 name: data["name"].as_str().map(|s| s.to_string()),
828 picture: data["avatar_url"].as_str().map(|s| s.to_string()),
829 email_verified: None, locale: None,
831 additional_fields,
832 }
833 }
834
835 Self::Google => {
836 let id = data["id"]
837 .as_str()
838 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
839 .to_string();
840
841 OAuthUserInfo {
842 id,
843 username: None, email: data["email"].as_str().map(|s| s.to_string()),
845 name: data["name"].as_str().map(|s| s.to_string()),
846 picture: data["picture"].as_str().map(|s| s.to_string()),
847 email_verified: data["verified_email"].as_bool(),
848 locale: data["locale"].as_str().map(|s| s.to_string()),
849 additional_fields,
850 }
851 }
852
853 _ => {
855 let id = data["id"]
857 .as_str()
858 .or_else(|| data["sub"].as_str())
859 .or_else(|| data["user_id"].as_str())
860 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
861 .to_string();
862
863 if let serde_json::Value::Object(map) = data {
865 additional_fields = map.into_iter().collect();
866 }
867
868 OAuthUserInfo {
869 id,
870 username: additional_fields
871 .get("username")
872 .or_else(|| additional_fields.get("login"))
873 .and_then(|v| v.as_str())
874 .map(|s| s.to_string()),
875 email: additional_fields
876 .get("email")
877 .and_then(|v| v.as_str())
878 .map(|s| s.to_string()),
879 name: additional_fields
880 .get("name")
881 .or_else(|| additional_fields.get("display_name"))
882 .and_then(|v| v.as_str())
883 .map(|s| s.to_string()),
884 picture: additional_fields
885 .get("avatar_url")
886 .or_else(|| additional_fields.get("picture"))
887 .and_then(|v| v.as_str())
888 .map(|s| s.to_string()),
889 email_verified: additional_fields
890 .get("email_verified")
891 .and_then(|v| v.as_bool()),
892 locale: additional_fields
893 .get("locale")
894 .and_then(|v| v.as_str())
895 .map(|s| s.to_string()),
896 additional_fields,
897 }
898 }
899 };
900
901 Ok(user_info)
902 }
903
904 pub async fn revoke_token(&self, access_token: &str) -> Result<()> {
906 let config = self.config();
907
908 let revocation_url = config.revocation_url.ok_or_else(|| {
909 AuthError::auth_method(
910 self.name(),
911 "Provider does not support token revocation".to_string(),
912 )
913 })?;
914
915 let client = reqwest::Client::new();
916 let mut params = HashMap::new();
917 params.insert("token".to_string(), access_token.to_string());
918
919 let response = client.post(&revocation_url).form(¶ms).send().await?;
920
921 if !response.status().is_success() {
922 let error_text = response.text().await.unwrap_or_default();
923 return Err(AuthError::auth_method(
924 self.name(),
925 format!("Token revocation failed: {error_text}"),
926 ));
927 }
928
929 Ok(())
930 }
931
932 pub async fn device_authorization(
934 &self,
935 client_id: &str,
936 scope: Option<&[String]>,
937 ) -> Result<DeviceAuthorizationResponse> {
938 let config = self.config();
939
940 if !config.supports_device_flow {
941 return Err(AuthError::auth_method(
942 self.name(),
943 "Provider does not support device authorization flow".to_string(),
944 ));
945 }
946
947 let client = reqwest::Client::new();
948
949 let scope_string = scope.unwrap_or(&config.default_scopes).join(" ");
950 let mut params = HashMap::new();
951 params.insert("client_id".to_string(), client_id.to_string());
952 params.insert("scope".to_string(), scope_string);
953
954 let device_auth_url = config.device_authorization_url.as_deref().ok_or_else(|| {
955 AuthError::auth_method(
956 self.name(),
957 "Device authorization URL is not configured for this provider".to_string(),
958 )
959 })?;
960
961 let response = client.post(device_auth_url).form(¶ms).send().await?;
962
963 if !response.status().is_success() {
964 let error_text = response.text().await.unwrap_or_default();
965 return Err(AuthError::auth_method(
966 self.name(),
967 format!("Device authorization request failed: {error_text}"),
968 ));
969 }
970
971 let device_response: DeviceAuthorizationResponse = response.json().await?;
972 Ok(device_response)
973 }
974
975 pub async fn poll_device_code(
977 &self,
978 client_id: &str,
979 device_code: &str,
980 _interval: Option<u64>,
981 ) -> Result<OAuthTokenResponse> {
982 let config = self.config();
983
984 if !config.supports_device_flow {
985 return Err(AuthError::auth_method(
986 self.name(),
987 "Provider does not support device authorization flow".to_string(),
988 ));
989 }
990
991 let client = reqwest::Client::new();
992
993 let mut params = HashMap::new();
994 params.insert("client_id".to_string(), client_id.to_string());
995 params.insert(
996 "grant_type".to_string(),
997 "urn:ietf:params:oauth:grant-type:device_code".to_string(),
998 );
999 params.insert("device_code".to_string(), device_code.to_string());
1000
1001 let response = client.post(&config.token_url).form(¶ms).send().await?;
1002
1003 if !response.status().is_success() {
1004 let error_text = response.text().await.unwrap_or_default();
1005 return Err(AuthError::auth_method(
1006 self.name(),
1007 format!("Token request failed: {error_text}"),
1008 ));
1009 }
1010
1011 let token_response: OAuthTokenResponse = response.json().await?;
1012 Ok(token_response)
1013 }
1014}
1015
1016pub fn generate_state() -> String {
1018 let mut bytes = [0u8; 32];
1019 use rand::Rng;
1020 rand::rng().fill_bytes(&mut bytes);
1021 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
1022}
1023
1024pub fn generate_pkce() -> (String, String) {
1026 use rand::Rng;
1027 use ring::digest;
1028
1029 let mut rng = rand::rng();
1031 let mut bytes = [0u8; 96]; rng.fill_bytes(&mut bytes);
1033 let code_verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
1034
1035 let digest = digest::digest(&digest::SHA256, code_verifier.as_bytes());
1037 let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest.as_ref());
1038
1039 (code_verifier, code_challenge)
1040}
1041
/// Fetches a user's profile from a provider's user API using an access token.
pub struct ProfileExtractor {
    /// HTTP client reused by every profile request this extractor makes.
    client: Client,
}
1046
1047impl ProfileExtractor {
1048 pub fn new() -> Self {
1050 Self {
1051 client: Client::new(),
1052 }
1053 }
1054
1055 pub async fn extract_profile(
1057 &self,
1058 token: &AuthToken,
1059 provider: &OAuthProvider,
1060 ) -> Result<ProviderProfile> {
1061 match provider {
1062 OAuthProvider::GitHub => self.extract_github_profile(token).await,
1063 OAuthProvider::Google => self.extract_google_profile(token).await,
1064 OAuthProvider::Microsoft => self.extract_microsoft_profile(token).await,
1065 OAuthProvider::Discord => self.extract_discord_profile(token).await,
1066 OAuthProvider::GitLab => self.extract_gitlab_profile(token).await,
1067 OAuthProvider::Custom { name, config } => {
1068 self.extract_custom_profile(token, name, config).await
1069 }
1070 _ => Err(AuthError::UnsupportedProvider(format!(
1071 "Profile extraction not supported for {:?}",
1072 provider
1073 ))),
1074 }
1075 }
1076
1077 async fn extract_github_profile(&self, token: &AuthToken) -> Result<ProviderProfile> {
1079 let response = self
1080 .client
1081 .get("https://api.github.com/user")
1082 .bearer_auth(&token.access_token)
1083 .send()
1084 .await
1085 .map_err(|e| AuthError::internal(e.to_string()))?;
1086
1087 let json: Value = response
1088 .json()
1089 .await
1090 .map_err(|e| AuthError::internal(e.to_string()))?;
1091
1092 let mut profile = ProviderProfile::new();
1093 profile = profile.with_id(json["id"].as_u64().unwrap_or(0).to_string());
1094 profile = profile.with_provider("github".to_string());
1095
1096 if let Some(login) = json["login"].as_str() {
1097 profile.username = Some(login.to_string());
1098 }
1099
1100 if let Some(name) = json["name"].as_str() {
1101 profile.name = Some(name.to_string());
1102 }
1103
1104 if let Some(email) = json["email"].as_str() {
1105 profile.email = Some(email.to_string());
1106 }
1107
1108 if let Some(avatar_url) = json["avatar_url"].as_str() {
1109 profile.picture = Some(avatar_url.to_string());
1110 }
1111
1112 if let Some(company) = json["company"].as_str() {
1114 profile
1115 .additional_data
1116 .insert("company".to_string(), Value::String(company.to_string()));
1117 }
1118
1119 if let Some(blog) = json["blog"].as_str() {
1120 profile
1121 .additional_data
1122 .insert("blog".to_string(), Value::String(blog.to_string()));
1123 }
1124
1125 if let Some(bio) = json["bio"].as_str() {
1126 profile
1127 .additional_data
1128 .insert("bio".to_string(), Value::String(bio.to_string()));
1129 }
1130
1131 Ok(profile)
1132 }
1133
1134 async fn extract_google_profile(&self, token: &AuthToken) -> Result<ProviderProfile> {
1136 let response = self
1137 .client
1138 .get("https://www.googleapis.com/oauth2/v2/userinfo")
1139 .bearer_auth(&token.access_token)
1140 .send()
1141 .await
1142 .map_err(|e| AuthError::internal(e.to_string()))?;
1143
1144 let json: Value = response
1145 .json()
1146 .await
1147 .map_err(|e| AuthError::internal(e.to_string()))?;
1148
1149 let mut profile = ProviderProfile::new();
1150 profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1151 profile = profile.with_provider("google".to_string());
1152
1153 if let Some(name) = json["name"].as_str() {
1154 profile.name = Some(name.to_string());
1155 }
1156
1157 if let Some(email) = json["email"].as_str() {
1158 profile.email = Some(email.to_string());
1159 }
1160
1161 if let Some(verified) = json["verified_email"].as_bool() {
1162 profile.email_verified = Some(verified);
1163 }
1164
1165 if let Some(picture) = json["picture"].as_str() {
1166 profile.picture = Some(picture.to_string());
1167 }
1168
1169 if let Some(locale) = json["locale"].as_str() {
1170 profile.locale = Some(locale.to_string());
1171 }
1172
1173 Ok(profile)
1174 }
1175
1176 async fn extract_microsoft_profile(&self, token: &AuthToken) -> Result<ProviderProfile> {
1178 let response = self
1179 .client
1180 .get("https://graph.microsoft.com/v1.0/me")
1181 .bearer_auth(&token.access_token)
1182 .send()
1183 .await
1184 .map_err(|e| AuthError::internal(e.to_string()))?;
1185
1186 let json: Value = response
1187 .json()
1188 .await
1189 .map_err(|e| AuthError::internal(e.to_string()))?;
1190
1191 let mut profile = ProviderProfile::new();
1192 profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1193 profile = profile.with_provider("microsoft".to_string());
1194
1195 if let Some(display_name) = json["displayName"].as_str() {
1196 profile.name = Some(display_name.to_string());
1197 }
1198
1199 if let Some(user_principal_name) = json["userPrincipalName"].as_str() {
1200 profile.username = Some(user_principal_name.to_string());
1201 }
1202
1203 if let Some(mail) = json["mail"].as_str() {
1204 profile.email = Some(mail.to_string());
1205 }
1206
1207 if let Some(preferred_language) = json["preferredLanguage"].as_str() {
1208 profile.locale = Some(preferred_language.to_string());
1209 }
1210
1211 if let Some(job_title) = json["jobTitle"].as_str() {
1213 profile
1214 .additional_data
1215 .insert("jobTitle".to_string(), Value::String(job_title.to_string()));
1216 }
1217
1218 if let Some(office_location) = json["officeLocation"].as_str() {
1219 profile.additional_data.insert(
1220 "officeLocation".to_string(),
1221 Value::String(office_location.to_string()),
1222 );
1223 }
1224
1225 Ok(profile)
1226 }
1227
1228 async fn extract_discord_profile(&self, token: &AuthToken) -> Result<ProviderProfile> {
1230 let response = self
1231 .client
1232 .get("https://discord.com/api/users/@me")
1233 .bearer_auth(&token.access_token)
1234 .send()
1235 .await
1236 .map_err(|e| AuthError::internal(e.to_string()))?;
1237
1238 let json: Value = response
1239 .json()
1240 .await
1241 .map_err(|e| AuthError::internal(e.to_string()))?;
1242
1243 let mut profile = ProviderProfile::new();
1244 profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1245 profile = profile.with_provider("discord".to_string());
1246
1247 if let Some(username) = json["username"].as_str() {
1248 profile.username = Some(username.to_string());
1249 }
1250
1251 if let Some(discriminator) = json["discriminator"].as_str() {
1252 profile.name = Some(format!(
1253 "{}#{}",
1254 json["username"].as_str().unwrap_or(""),
1255 discriminator
1256 ));
1257 }
1258
1259 if let Some(email) = json["email"].as_str() {
1260 profile.email = Some(email.to_string());
1261 }
1262
1263 if let Some(verified) = json["verified"].as_bool() {
1264 profile.email_verified = Some(verified);
1265 }
1266
1267 if let Some(avatar) = json["avatar"].as_str() {
1268 let user_id = json["id"].as_str().unwrap_or("");
1269 profile.picture = Some(format!(
1270 "https://cdn.discordapp.com/avatars/{}/{}.png",
1271 user_id, avatar
1272 ));
1273 }
1274
1275 if let Some(locale) = json["locale"].as_str() {
1276 profile.locale = Some(locale.to_string());
1277 }
1278
1279 Ok(profile)
1280 }
1281
1282 async fn extract_gitlab_profile(&self, token: &AuthToken) -> Result<ProviderProfile> {
1284 let response = self
1285 .client
1286 .get("https://gitlab.com/api/v4/user")
1287 .bearer_auth(&token.access_token)
1288 .send()
1289 .await
1290 .map_err(|e| AuthError::internal(e.to_string()))?;
1291
1292 let json: Value = response
1293 .json()
1294 .await
1295 .map_err(|e| AuthError::internal(e.to_string()))?;
1296
1297 let mut profile = ProviderProfile::new();
1298 profile = profile.with_id(json["id"].as_u64().unwrap_or(0).to_string());
1299 profile = profile.with_provider("gitlab".to_string());
1300
1301 if let Some(username) = json["username"].as_str() {
1302 profile.username = Some(username.to_string());
1303 }
1304
1305 if let Some(name) = json["name"].as_str() {
1306 profile.name = Some(name.to_string());
1307 }
1308
1309 if let Some(email) = json["email"].as_str() {
1310 profile.email = Some(email.to_string());
1311 }
1312
1313 if let Some(avatar_url) = json["avatar_url"].as_str() {
1314 profile.picture = Some(avatar_url.to_string());
1315 }
1316
1317 if let Some(web_url) = json["web_url"].as_str() {
1319 profile
1320 .additional_data
1321 .insert("web_url".to_string(), Value::String(web_url.to_string()));
1322 }
1323
1324 if let Some(bio) = json["bio"].as_str() {
1325 profile
1326 .additional_data
1327 .insert("bio".to_string(), Value::String(bio.to_string()));
1328 }
1329
1330 Ok(profile)
1331 }
1332
1333 async fn extract_custom_profile(
1335 &self,
1336 token: &AuthToken,
1337 provider_name: &str,
1338 config: &OAuthProviderConfig,
1339 ) -> Result<ProviderProfile> {
1340 if let Some(user_info_url) = &config.userinfo_url {
1341 let response = self
1342 .client
1343 .get(user_info_url)
1344 .bearer_auth(&token.access_token)
1345 .send()
1346 .await
1347 .map_err(|e| AuthError::internal(e.to_string()))?;
1348
1349 let json: Value = response
1350 .json()
1351 .await
1352 .map_err(|e| AuthError::internal(e.to_string()))?;
1353
1354 let mut profile = ProviderProfile::new();
1355 profile = profile.with_id(
1356 json["id"]
1357 .as_str()
1358 .or_else(|| json["sub"].as_str())
1359 .unwrap_or("")
1360 .to_string(),
1361 );
1362 profile = profile.with_provider(provider_name.to_string());
1363
1364 if let Some(username) = json["username"].as_str().or_else(|| json["login"].as_str()) {
1366 profile.username = Some(username.to_string());
1367 }
1368
1369 if let Some(name) = json["name"]
1370 .as_str()
1371 .or_else(|| json["display_name"].as_str())
1372 {
1373 profile.name = Some(name.to_string());
1374 }
1375
1376 if let Some(email) = json["email"].as_str() {
1377 profile.email = Some(email.to_string());
1378 }
1379
1380 if let Some(verified) = json["email_verified"]
1381 .as_bool()
1382 .or_else(|| json["verified"].as_bool())
1383 {
1384 profile.email_verified = Some(verified);
1385 }
1386
1387 if let Some(picture) = json["picture"]
1388 .as_str()
1389 .or_else(|| json["avatar_url"].as_str())
1390 {
1391 profile.picture = Some(picture.to_string());
1392 }
1393
1394 if let Some(locale) = json["locale"].as_str().or_else(|| json["lang"].as_str()) {
1395 profile.locale = Some(locale.to_string());
1396 }
1397
1398 for (key, value) in json.as_object().unwrap_or(&serde_json::Map::new()) {
1400 if ![
1401 "id",
1402 "sub",
1403 "username",
1404 "login",
1405 "name",
1406 "display_name",
1407 "email",
1408 "email_verified",
1409 "verified",
1410 "picture",
1411 "avatar_url",
1412 "locale",
1413 "lang",
1414 ]
1415 .contains(&key.as_str())
1416 {
1417 profile.additional_data.insert(key.clone(), value.clone());
1418 }
1419 }
1420
1421 Ok(profile)
1422 } else {
1423 Err(AuthError::config(
1424 "Custom provider requires user_info_url".to_string(),
1425 ))
1426 }
1427 }
1428}
1429
1430impl Default for ProfileExtractor {
1431 fn default() -> Self {
1432 Self::new()
1433 }
1434}
1435
1436impl fmt::Display for OAuthProvider {
1437 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1438 match self {
1439 OAuthProvider::GitHub => write!(f, "github"),
1440 OAuthProvider::Google => write!(f, "google"),
1441 OAuthProvider::Microsoft => write!(f, "microsoft"),
1442 OAuthProvider::Discord => write!(f, "discord"),
1443 OAuthProvider::Twitter => write!(f, "twitter"),
1444 OAuthProvider::Facebook => write!(f, "facebook"),
1445 OAuthProvider::LinkedIn => write!(f, "linkedin"),
1446 OAuthProvider::GitLab => write!(f, "gitlab"),
1447 OAuthProvider::Custom { name, .. } => write!(f, "{}", name),
1448 }
1449 }
1450}
1451
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in GitHub config points at GitHub's OAuth endpoints and enables PKCE.
    #[test]
    fn test_provider_config() {
        let provider = OAuthProvider::GitHub;
        let cfg = provider.config();

        assert_eq!(
            cfg.authorization_url,
            "https://github.com/login/oauth/authorize"
        );
        assert_eq!(
            cfg.token_url,
            "https://github.com/login/oauth/access_token"
        );
        assert!(cfg.supports_pkce);
    }

    /// The authorization URL carries the client id, URL-encoded redirect URI,
    /// state, and PKCE challenge as query parameters.
    #[test]
    fn test_authorization_url() {
        let provider = OAuthProvider::GitHub;
        let built = provider
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                None,
                Some("challenge123"),
            )
            .unwrap();

        let expected_fragments = [
            "client_id=client123",
            "redirect_uri=https%3A%2F%2Fexample.com%2Fcallback",
            "state=state123",
            "code_challenge=challenge123",
        ];
        for fragment in expected_fragments {
            assert!(
                built.contains(fragment),
                "missing `{}` in {}",
                fragment,
                built
            );
        }
    }

    /// State tokens are 43 characters long and differ between calls.
    #[test]
    fn test_generate_state() {
        let first = generate_state();
        let second = generate_state();

        for token in [&first, &second] {
            assert_eq!(token.len(), 43);
        }
        assert_ne!(first, second);
    }

    /// PKCE verifiers are 128 characters long; verifiers and challenges differ
    /// between calls.
    #[test]
    fn test_generate_pkce() {
        let (first_verifier, first_challenge) = generate_pkce();
        let (second_verifier, second_challenge) = generate_pkce();

        for verifier in [&first_verifier, &second_verifier] {
            assert_eq!(verifier.len(), 128);
        }
        assert_ne!(first_verifier, second_verifier);
        assert_ne!(first_challenge, second_challenge);
    }
}