1impl Default for UserProfile {
3 fn default() -> Self {
4 Self::new()
5 }
6}
7use crate::errors::{AuthError, Result};
8use crate::tokens::AuthToken;
9use base64::Engine;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::fmt;
15use url::Url;
16
/// An OAuth 2.0 identity provider: one of the built-in presets, or a `Custom`
/// provider described by an explicit [`OAuthProviderConfig`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OAuthProvider {
    /// GitHub (`github.com`).
    GitHub,

    /// Google accounts.
    Google,

    /// Microsoft identity platform (`login.microsoftonline.com`, "common" tenant).
    Microsoft,

    /// Discord.
    Discord,

    /// Twitter / X (OAuth 2.0, API v2).
    Twitter,

    /// Facebook (Graph API).
    Facebook,

    /// LinkedIn.
    LinkedIn,

    /// GitLab.com.
    GitLab,

    /// A provider configured at runtime; build one with [`OAuthProvider::custom`].
    Custom {
        /// Identifier returned by [`OAuthProvider::name`].
        name: String,
        /// Endpoints and capabilities returned (cloned) by [`OAuthProvider::config`].
        config: Box<OAuthProviderConfig>,
    },
}
50
/// Endpoints and capability flags describing one OAuth provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthProviderConfig {
    /// Authorization endpoint the user is sent to; query parameters are appended
    /// by `OAuthProvider::build_authorization_url`.
    pub authorization_url: String,

    /// Token endpoint used for code exchange, refresh and device-code polling.
    pub token_url: String,

    /// Device authorization endpoint (RFC 8628), if the provider has one.
    pub device_authorization_url: Option<String>,

    /// User-info endpoint queried with the access token as a bearer credential.
    pub userinfo_url: Option<String>,

    /// Token revocation endpoint, if any.
    pub revocation_url: Option<String>,

    /// Scopes requested when the caller does not pass an explicit scope list.
    pub default_scopes: Vec<String>,

    /// Whether a PKCE `code_challenge` is added to the authorization URL.
    pub supports_pkce: bool,

    /// Whether `refresh_token` is allowed; when false it returns an error.
    pub supports_refresh: bool,

    /// Whether the device authorization flow functions are allowed.
    pub supports_device_flow: bool,

    /// Extra query parameters appended to the authorization URL.
    pub additional_params: HashMap<String, String>,
}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct DeviceAuthorizationResponse {
88 pub device_code: String,
90
91 pub user_code: String,
93
94 pub verification_uri: String,
96
97 pub verification_uri_complete: Option<String>,
99
100 pub interval: u64,
102
103 pub expires_in: u64,
105}
106
/// Provider-independent user profile assembled from an ID token or a provider's
/// user API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserProfile {
    /// Provider-specific user identifier (e.g. the JWT `sub` claim).
    pub id: Option<String>,

    /// Provider name, e.g. `"github"` or `"google"`.
    pub provider: Option<String>,

    /// Login / handle, where the provider has one.
    pub username: Option<String>,

    /// Human-readable display name.
    pub name: Option<String>,

    /// Email address, if the provider disclosed one.
    pub email: Option<String>,

    /// Whether the provider reports the email as verified (`None` = not reported).
    pub email_verified: Option<bool>,

    /// Avatar / picture URL.
    pub picture: Option<String>,

    /// Locale or language tag as reported by the provider.
    pub locale: Option<String>,

    /// Extra provider-specific values (for ID tokens, the full claim set is stored
    /// under `"id_token_claims"`).
    pub additional_data: HashMap<String, serde_json::Value>,
}
137
138#[cfg(feature = "postgres-storage")]
139use sqlx::{Decode, Postgres, Type, postgres::PgValueRef};
140
#[cfg(feature = "postgres-storage")]
impl<'r> Decode<'r, Postgres> for UserProfile {
    /// Reads the column as JSON and deserializes it into a [`UserProfile`].
    fn decode(value: PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let raw = <Value as Decode<'r, Postgres>>::decode(value)?;
        // `?` boxes the serde_json error into `BoxDynError`.
        let profile: Self = serde_json::from_value(raw)?;
        Ok(profile)
    }
}
148
#[cfg(feature = "postgres-storage")]
impl Type<Postgres> for UserProfile {
    /// Column types that can hold a profile: whatever accepts a JSON value.
    fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
        <Value as Type<Postgres>>::compatible(ty)
    }

    /// A profile is stored with the same SQL type as a JSON value.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        <Value as Type<Postgres>>::type_info()
    }
}
158
159impl UserProfile {
160 pub fn new() -> Self {
162 Self {
163 id: None,
164 provider: None,
165 username: None,
166 name: None,
167 email: None,
168 email_verified: None,
169 picture: None,
170 locale: None,
171 additional_data: HashMap::new(),
172 }
173 }
174
175 pub fn with_id(mut self, id: impl Into<String>) -> Self {
177 self.id = Some(id.into());
178 self
179 }
180
181 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
183 self.provider = Some(provider.into());
184 self
185 }
186
187 pub fn with_username(mut self, username: Option<impl Into<String>>) -> Self {
189 self.username = username.map(Into::into);
190 self
191 }
192
193 pub fn with_name(mut self, name: Option<impl Into<String>>) -> Self {
195 self.name = name.map(Into::into);
196 self
197 }
198
199 pub fn with_email(mut self, email: Option<impl Into<String>>) -> Self {
201 self.email = email.map(Into::into);
202 self
203 }
204
205 pub fn with_email_verified(mut self, verified: bool) -> Self {
207 self.email_verified = Some(verified);
208 self
209 }
210
211 pub fn with_picture(mut self, picture: Option<impl Into<String>>) -> Self {
213 self.picture = picture.map(Into::into);
214 self
215 }
216
217 pub fn with_locale(mut self, locale: Option<impl Into<String>>) -> Self {
219 self.locale = locale.map(Into::into);
220 self
221 }
222
223 pub fn with_additional_data(
225 mut self,
226 key: impl Into<String>,
227 value: serde_json::Value,
228 ) -> Self {
229 self.additional_data.insert(key.into(), value);
230 self
231 }
232
233 pub fn from_token_response(
235 token: &OAuthTokenResponse,
236 provider: &OAuthProvider,
237 ) -> Option<Self> {
238 if let Some(id_token_value) = token.additional_fields.get("id_token")
240 && let Some(id_token) = id_token_value.as_str()
241 && let Ok(profile) = Self::from_id_token(id_token)
242 {
243 return Some(profile.with_provider(provider.to_string()));
244 }
245 None
246 }
247
248 pub fn from_id_token(id_token: &str) -> Result<Self> {
250 let parts: Vec<&str> = id_token.split('.').collect();
252 if parts.len() != 3 {
253 return Err(AuthError::validation("Invalid JWT format"));
254 }
255
256 let payload = parts[1];
258 let padding_len = payload.len() % 4;
259 let padded_payload = if padding_len > 0 {
260 format!("{}{}", payload, "=".repeat(4 - padding_len))
261 } else {
262 payload.to_string()
263 };
264
265 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
267 .decode(&padded_payload)
268 .map_err(|e| AuthError::validation(format!("Failed to decode JWT: {}", e)))?;
269
270 let json: Value = serde_json::from_slice(&decoded)
272 .map_err(|e| AuthError::validation(format!("Failed to parse JWT payload: {}", e)))?;
273
274 let mut profile = Self::new();
276
277 if let Some(sub) = json.get("sub").and_then(|v| v.as_str()) {
279 profile = profile.with_id(sub);
280 } else if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
281 profile = profile.with_id(id);
282 } else {
283 return Err(AuthError::validation("JWT missing subject claim"));
284 }
285
286 if let Some(name) = json.get("name").and_then(|v| v.as_str()) {
288 profile = profile.with_name(Some(name));
289 }
290
291 if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
292 profile = profile.with_email(Some(email));
293 }
294
295 if let Some(verified) = json.get("email_verified").and_then(|v| v.as_bool()) {
296 profile = profile.with_email_verified(verified);
297 }
298
299 if let Some(preferred_username) = json.get("preferred_username").and_then(|v| v.as_str()) {
300 profile = profile.with_username(Some(preferred_username));
301 }
302
303 if let Some(picture) = json.get("picture").and_then(|v| v.as_str()) {
304 profile = profile.with_picture(Some(picture));
305 }
306
307 if let Some(locale) = json.get("locale").and_then(|v| v.as_str()) {
308 profile = profile.with_locale(Some(locale));
309 }
310
311 profile = profile.with_additional_data("id_token_claims", json);
313
314 Ok(profile)
315 }
316
317 pub fn to_auth_token(&self, access_token: String) -> AuthToken {
319 let user_id = self.id.as_deref().unwrap_or("unknown").to_string();
320 let auth_method = self.provider.as_deref().unwrap_or("oauth").to_string();
321 let expires_in = std::time::Duration::from_secs(3600); let mut token = AuthToken::new(user_id.clone(), access_token, expires_in, auth_method);
324 token.subject = self.id.clone();
325 token.issuer = self.provider.clone();
326 token.user_profile = Some(self.clone());
327 token
328 }
329
330 pub fn has_id(&self) -> bool {
332 self.id.is_some()
333 }
334
335 pub fn display_name(&self) -> Option<&str> {
337 self.name.as_deref().or(self.username.as_deref())
338 }
339}
340
/// Token endpoint response produced by the authorization-code, refresh-token and
/// device-code grants in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthTokenResponse {
    /// The issued access token.
    pub access_token: String,

    /// Token type exactly as reported by the provider.
    pub token_type: String,

    /// Access-token lifetime in seconds, if reported.
    pub expires_in: Option<u64>,

    /// Refresh token, if one was issued.
    pub refresh_token: Option<String>,

    /// Granted scope string as returned by the provider (delimiter is provider-specific).
    pub scope: Option<String>,

    /// Every other top-level response field, e.g. `id_token`, which
    /// `UserProfile::from_token_response` reads from here.
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
363
/// Normalized result of a provider's user-info endpoint
/// (see `OAuthProvider::get_user_info`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUserInfo {
    /// Provider-specific user ID, always rendered as a string.
    pub id: String,

    /// Login / handle, if available.
    pub username: Option<String>,

    /// Display name, if available.
    pub name: Option<String>,

    /// Email address, if disclosed.
    pub email: Option<String>,

    /// Email verification status, if reported.
    pub email_verified: Option<bool>,

    /// Avatar / picture URL, if available.
    pub picture: Option<String>,

    /// Locale, if reported.
    pub locale: Option<String>,

    /// Raw fields of the user-info object for generically parsed providers;
    /// left empty by the GitHub and Google parsers.
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
392
393impl OAuthProvider {
394 pub fn config(&self) -> OAuthProviderConfig {
396 match self {
397 Self::GitHub => OAuthProviderConfig {
398 authorization_url: "https://github.com/login/oauth/authorize".to_string(),
399 token_url: "https://github.com/login/oauth/access_token".to_string(),
400 device_authorization_url: Some("https://github.com/login/device/code".to_string()),
401 userinfo_url: Some("https://api.github.com/user".to_string()),
402 revocation_url: None,
403 default_scopes: vec!["user:email".to_string()],
404 supports_pkce: true,
405 supports_refresh: false,
406 supports_device_flow: true,
407 additional_params: HashMap::new(),
408 },
409
410 Self::Google => OAuthProviderConfig {
411 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
412 token_url: "https://oauth2.googleapis.com/token".to_string(),
413 device_authorization_url: Some(
414 "https://oauth2.googleapis.com/device/code".to_string(),
415 ),
416 userinfo_url: Some("https://www.googleapis.com/oauth2/v2/userinfo".to_string()),
417 revocation_url: Some("https://oauth2.googleapis.com/revoke".to_string()),
418 default_scopes: vec![
419 "openid".to_string(),
420 "profile".to_string(),
421 "email".to_string(),
422 ],
423 supports_pkce: true,
424 supports_refresh: true,
425 supports_device_flow: true,
426 additional_params: HashMap::new(),
427 },
428
429 Self::Microsoft => OAuthProviderConfig {
430 authorization_url: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
431 .to_string(),
432 token_url: "https://login.microsoftonline.com/common/oauth2/v2.0/token".to_string(),
433 device_authorization_url: Some(
434 "https://login.microsoftonline.com/common/oauth2/v2.0/devicecode".to_string(),
435 ),
436 userinfo_url: Some("https://graph.microsoft.com/v1.0/me".to_string()),
437 revocation_url: None,
438 default_scopes: vec![
439 "openid".to_string(),
440 "profile".to_string(),
441 "email".to_string(),
442 ],
443 supports_pkce: true,
444 supports_refresh: true,
445 supports_device_flow: true,
446 additional_params: HashMap::new(),
447 },
448
449 Self::Discord => OAuthProviderConfig {
450 authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
451 token_url: "https://discord.com/api/oauth2/token".to_string(),
452 device_authorization_url: None,
453 userinfo_url: Some("https://discord.com/api/users/@me".to_string()),
454 revocation_url: Some("https://discord.com/api/oauth2/token/revoke".to_string()),
455 default_scopes: vec!["identify".to_string(), "email".to_string()],
456 supports_pkce: false,
457 supports_refresh: true,
458 supports_device_flow: false,
459 additional_params: HashMap::new(),
460 },
461
462 Self::Twitter => OAuthProviderConfig {
463 authorization_url: "https://twitter.com/i/oauth2/authorize".to_string(),
464 token_url: "https://api.twitter.com/2/oauth2/token".to_string(),
465 device_authorization_url: None,
466 userinfo_url: Some("https://api.twitter.com/2/users/me".to_string()),
467 revocation_url: Some("https://api.twitter.com/2/oauth2/revoke".to_string()),
468 default_scopes: vec!["tweet.read".to_string(), "users.read".to_string()],
469 supports_pkce: true,
470 supports_refresh: true,
471 supports_device_flow: false,
472 additional_params: HashMap::new(),
473 },
474
475 Self::Facebook => OAuthProviderConfig {
476 authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
477 token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
478 device_authorization_url: None,
479 userinfo_url: Some("https://graph.facebook.com/me".to_string()),
480 revocation_url: None,
481 default_scopes: vec!["email".to_string(), "public_profile".to_string()],
482 supports_pkce: false,
483 supports_refresh: false,
484 supports_device_flow: false,
485 additional_params: HashMap::new(),
486 },
487
488 Self::LinkedIn => OAuthProviderConfig {
489 authorization_url: "https://www.linkedin.com/oauth/v2/authorization".to_string(),
490 token_url: "https://www.linkedin.com/oauth/v2/accessToken".to_string(),
491 device_authorization_url: None,
492 userinfo_url: Some("https://api.linkedin.com/v2/me".to_string()),
493 revocation_url: None,
494 default_scopes: vec!["r_liteprofile".to_string(), "r_emailaddress".to_string()],
495 supports_pkce: false,
496 supports_refresh: true,
497 supports_device_flow: false,
498 additional_params: HashMap::new(),
499 },
500
501 Self::GitLab => OAuthProviderConfig {
502 authorization_url: "https://gitlab.com/oauth/authorize".to_string(),
503 token_url: "https://gitlab.com/oauth/token".to_string(),
504 device_authorization_url: None,
505 userinfo_url: Some("https://gitlab.com/api/v4/user".to_string()),
506 revocation_url: Some("https://gitlab.com/oauth/revoke".to_string()),
507 default_scopes: vec!["read_user".to_string()],
508 supports_pkce: true,
509 supports_refresh: true,
510 supports_device_flow: false,
511 additional_params: HashMap::new(),
512 },
513
514 Self::Custom { config, .. } => *config.clone(),
515 }
516 }
517
518 pub fn name(&self) -> &str {
520 match self {
521 Self::GitHub => "github",
522 Self::Google => "google",
523 Self::Microsoft => "microsoft",
524 Self::Discord => "discord",
525 Self::Twitter => "twitter",
526 Self::Facebook => "facebook",
527 Self::LinkedIn => "linkedin",
528 Self::GitLab => "gitlab",
529 Self::Custom { name, .. } => name,
530 }
531 }
532
533 pub fn custom(name: impl Into<String>, config: OAuthProviderConfig) -> Self {
535 Self::Custom {
536 name: name.into(),
537 config: Box::new(config),
538 }
539 }
540
541 pub fn build_authorization_url(
543 &self,
544 client_id: &str,
545 redirect_uri: &str,
546 state: &str,
547 scopes: Option<&[String]>,
548 code_challenge: Option<&str>,
549 ) -> Result<String> {
550 let config = self.config();
551 let mut url = Url::parse(&config.authorization_url)
552 .map_err(|e| AuthError::config(format!("Invalid authorization URL: {e}")))?;
553
554 let scopes = scopes.unwrap_or(&config.default_scopes);
555
556 {
557 let mut query = url.query_pairs_mut();
558 query.append_pair("client_id", client_id);
559 query.append_pair("redirect_uri", redirect_uri);
560 query.append_pair("response_type", "code");
561 query.append_pair("state", state);
562
563 if !scopes.is_empty() {
564 query.append_pair("scope", &scopes.join(" "));
565 }
566
567 if config.supports_pkce
569 && let Some(challenge) = code_challenge
570 {
571 query.append_pair("code_challenge", challenge);
572 query.append_pair("code_challenge_method", "S256");
573 }
574
575 for (key, value) in &config.additional_params {
577 query.append_pair(key, value);
578 }
579 }
580
581 Ok(url.to_string())
582 }
583
584 pub async fn exchange_code(
586 &self,
587 client_id: &str,
588 client_secret: &str,
589 authorization_code: &str,
590 redirect_uri: &str,
591 code_verifier: Option<&str>,
592 ) -> Result<OAuthTokenResponse> {
593 let config = self.config();
594 let client = reqwest::Client::new();
595
596 let mut params = vec![
597 ("grant_type", "authorization_code"),
598 ("client_id", client_id),
599 ("client_secret", client_secret),
600 ("code", authorization_code),
601 ("redirect_uri", redirect_uri),
602 ];
603
604 if let Some(verifier) = code_verifier {
606 params.push(("code_verifier", verifier));
607 }
608
609 let response = client.post(&config.token_url).form(¶ms).send().await?;
610
611 if !response.status().is_success() {
612 let error_text = response.text().await.unwrap_or_default();
613 return Err(AuthError::auth_method(
614 self.name(),
615 format!("Token exchange failed: {error_text}"),
616 ));
617 }
618
619 let token_response: OAuthTokenResponse = response.json().await?;
620 Ok(token_response)
621 }
622
623 pub async fn refresh_token(
625 &self,
626 client_id: &str,
627 client_secret: &str,
628 refresh_token: &str,
629 ) -> Result<OAuthTokenResponse> {
630 let config = self.config();
631
632 if !config.supports_refresh {
633 return Err(AuthError::auth_method(
634 self.name(),
635 "Provider does not support token refresh".to_string(),
636 ));
637 }
638
639 let client = reqwest::Client::new();
640
641 let params = vec![
642 ("grant_type", "refresh_token"),
643 ("client_id", client_id),
644 ("client_secret", client_secret),
645 ("refresh_token", refresh_token),
646 ];
647
648 let response = client.post(&config.token_url).form(¶ms).send().await?;
649
650 if !response.status().is_success() {
651 let error_text = response.text().await.unwrap_or_default();
652 return Err(AuthError::auth_method(
653 self.name(),
654 format!("Token refresh failed: {error_text}"),
655 ));
656 }
657
658 let token_response: OAuthTokenResponse = response.json().await?;
659 Ok(token_response)
660 }
661
662 pub async fn get_user_info(&self, access_token: &str) -> Result<OAuthUserInfo> {
664 let config = self.config();
665
666 let userinfo_url = config.userinfo_url.ok_or_else(|| {
667 AuthError::auth_method(
668 self.name(),
669 "Provider does not support user info endpoint".to_string(),
670 )
671 })?;
672
673 let client = reqwest::Client::new();
674 let response = client
675 .get(&userinfo_url)
676 .bearer_auth(access_token)
677 .send()
678 .await?;
679
680 if !response.status().is_success() {
681 let error_text = response.text().await.unwrap_or_default();
682 return Err(AuthError::auth_method(
683 self.name(),
684 format!("User info request failed: {error_text}"),
685 ));
686 }
687
688 let user_data: serde_json::Value = response.json().await?;
689
690 let user_info = self.parse_user_info(user_data)?;
692 Ok(user_info)
693 }
694
695 fn parse_user_info(&self, data: serde_json::Value) -> Result<OAuthUserInfo> {
697 let mut additional_fields = HashMap::new();
698
699 let user_info = match self {
700 Self::GitHub => {
701 let id = data["id"]
702 .as_u64()
703 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
704 .to_string();
705
706 OAuthUserInfo {
707 id,
708 username: data["login"].as_str().map(|s| s.to_string()),
709 email: data["email"].as_str().map(|s| s.to_string()),
710 name: data["name"].as_str().map(|s| s.to_string()),
711 picture: data["avatar_url"].as_str().map(|s| s.to_string()),
712 email_verified: None, locale: None,
714 additional_fields,
715 }
716 }
717
718 Self::Google => {
719 let id = data["id"]
720 .as_str()
721 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
722 .to_string();
723
724 OAuthUserInfo {
725 id,
726 username: None, email: data["email"].as_str().map(|s| s.to_string()),
728 name: data["name"].as_str().map(|s| s.to_string()),
729 picture: data["picture"].as_str().map(|s| s.to_string()),
730 email_verified: data["verified_email"].as_bool(),
731 locale: data["locale"].as_str().map(|s| s.to_string()),
732 additional_fields,
733 }
734 }
735
736 _ => {
738 let id = data["id"]
740 .as_str()
741 .or_else(|| data["sub"].as_str())
742 .or_else(|| data["user_id"].as_str())
743 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
744 .to_string();
745
746 if let serde_json::Value::Object(map) = data {
748 additional_fields = map.into_iter().collect();
749 }
750
751 OAuthUserInfo {
752 id,
753 username: additional_fields
754 .get("username")
755 .or_else(|| additional_fields.get("login"))
756 .and_then(|v| v.as_str())
757 .map(|s| s.to_string()),
758 email: additional_fields
759 .get("email")
760 .and_then(|v| v.as_str())
761 .map(|s| s.to_string()),
762 name: additional_fields
763 .get("name")
764 .or_else(|| additional_fields.get("display_name"))
765 .and_then(|v| v.as_str())
766 .map(|s| s.to_string()),
767 picture: additional_fields
768 .get("avatar_url")
769 .or_else(|| additional_fields.get("picture"))
770 .and_then(|v| v.as_str())
771 .map(|s| s.to_string()),
772 email_verified: additional_fields
773 .get("email_verified")
774 .and_then(|v| v.as_bool()),
775 locale: additional_fields
776 .get("locale")
777 .and_then(|v| v.as_str())
778 .map(|s| s.to_string()),
779 additional_fields,
780 }
781 }
782 };
783
784 Ok(user_info)
785 }
786
787 pub async fn revoke_token(&self, access_token: &str) -> Result<()> {
789 let config = self.config();
790
791 let revocation_url = config.revocation_url.ok_or_else(|| {
792 AuthError::auth_method(
793 self.name(),
794 "Provider does not support token revocation".to_string(),
795 )
796 })?;
797
798 let client = reqwest::Client::new();
799 let response = client
800 .post(&revocation_url)
801 .form(&[("token", access_token)])
802 .send()
803 .await?;
804
805 if !response.status().is_success() {
806 let error_text = response.text().await.unwrap_or_default();
807 return Err(AuthError::auth_method(
808 self.name(),
809 format!("Token revocation failed: {error_text}"),
810 ));
811 }
812
813 Ok(())
814 }
815
816 pub async fn device_authorization(
818 &self,
819 client_id: &str,
820 scope: Option<&[String]>,
821 ) -> Result<DeviceAuthorizationResponse> {
822 let config = self.config();
823
824 if !config.supports_device_flow {
825 return Err(AuthError::auth_method(
826 self.name(),
827 "Provider does not support device authorization flow".to_string(),
828 ));
829 }
830
831 let client = reqwest::Client::new();
832
833 let scope_string = scope.unwrap_or(&config.default_scopes).join(" ");
834 let params = vec![("client_id", client_id), ("scope", scope_string.as_str())];
835
836 let response = client
837 .post(config.device_authorization_url.as_deref().unwrap())
838 .form(¶ms)
839 .send()
840 .await?;
841
842 if !response.status().is_success() {
843 let error_text = response.text().await.unwrap_or_default();
844 return Err(AuthError::auth_method(
845 self.name(),
846 format!("Device authorization request failed: {error_text}"),
847 ));
848 }
849
850 let device_response: DeviceAuthorizationResponse = response.json().await?;
851 Ok(device_response)
852 }
853
854 pub async fn poll_device_code(
856 &self,
857 client_id: &str,
858 device_code: &str,
859 _interval: Option<u64>,
860 ) -> Result<OAuthTokenResponse> {
861 let config = self.config();
862
863 if !config.supports_device_flow {
864 return Err(AuthError::auth_method(
865 self.name(),
866 "Provider does not support device authorization flow".to_string(),
867 ));
868 }
869
870 let client = reqwest::Client::new();
871
872 let params = vec![
873 ("client_id", client_id),
874 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
875 ("device_code", device_code),
876 ];
877
878 let response = client.post(&config.token_url).form(¶ms).send().await?;
879
880 if !response.status().is_success() {
881 let error_text = response.text().await.unwrap_or_default();
882 return Err(AuthError::auth_method(
883 self.name(),
884 format!("Token request failed: {error_text}"),
885 ));
886 }
887
888 let token_response: OAuthTokenResponse = response.json().await?;
889 Ok(token_response)
890 }
891}
892
893pub fn generate_state() -> String {
895 let mut bytes = [0u8; 32];
896 use rand::RngCore;
897 rand::thread_rng().fill_bytes(&mut bytes);
898 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
899}
900
901pub fn generate_pkce() -> (String, String) {
903 use rand::RngCore;
904 use ring::digest;
905
906 let mut rng = rand::thread_rng();
908 let mut bytes = [0u8; 96]; rng.fill_bytes(&mut bytes);
910 let code_verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
911
912 let digest = digest::digest(&digest::SHA256, code_verifier.as_bytes());
914 let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest.as_ref());
915
916 (code_verifier, code_challenge)
917}
918
/// Fetches a [`UserProfile`] from a provider's user API using an already-issued
/// access token.
pub struct ProfileExtractor {
    /// HTTP client used for every profile request made by this extractor.
    client: Client,
}
923
924impl ProfileExtractor {
925 pub fn new() -> Self {
927 Self {
928 client: Client::new(),
929 }
930 }
931
932 pub async fn extract_profile(
934 &self,
935 token: &AuthToken,
936 provider: &OAuthProvider,
937 ) -> Result<UserProfile> {
938 match provider {
939 OAuthProvider::GitHub => self.extract_github_profile(token).await,
940 OAuthProvider::Google => self.extract_google_profile(token).await,
941 OAuthProvider::Microsoft => self.extract_microsoft_profile(token).await,
942 OAuthProvider::Discord => self.extract_discord_profile(token).await,
943 OAuthProvider::GitLab => self.extract_gitlab_profile(token).await,
944 OAuthProvider::Custom { name, config } => {
945 self.extract_custom_profile(token, name, config).await
946 }
947 _ => Err(AuthError::UnsupportedProvider(format!(
948 "Profile extraction not supported for {:?}",
949 provider
950 ))),
951 }
952 }
953
954 async fn extract_github_profile(&self, token: &AuthToken) -> Result<UserProfile> {
956 let response = self
957 .client
958 .get("https://api.github.com/user")
959 .bearer_auth(&token.access_token)
960 .send()
961 .await
962 .map_err(|e| AuthError::NetworkError(e.to_string()))?;
963
964 let json: Value = response
965 .json()
966 .await
967 .map_err(|e| AuthError::ParseError(e.to_string()))?;
968
969 let mut profile = UserProfile::new();
970 profile = profile.with_id(json["id"].as_u64().unwrap_or(0).to_string());
971 profile = profile.with_provider("github".to_string());
972
973 if let Some(login) = json["login"].as_str() {
974 profile.username = Some(login.to_string());
975 }
976
977 if let Some(name) = json["name"].as_str() {
978 profile.name = Some(name.to_string());
979 }
980
981 if let Some(email) = json["email"].as_str() {
982 profile.email = Some(email.to_string());
983 }
984
985 if let Some(avatar_url) = json["avatar_url"].as_str() {
986 profile.picture = Some(avatar_url.to_string());
987 }
988
989 if let Some(company) = json["company"].as_str() {
991 profile
992 .additional_data
993 .insert("company".to_string(), Value::String(company.to_string()));
994 }
995
996 if let Some(blog) = json["blog"].as_str() {
997 profile
998 .additional_data
999 .insert("blog".to_string(), Value::String(blog.to_string()));
1000 }
1001
1002 if let Some(bio) = json["bio"].as_str() {
1003 profile
1004 .additional_data
1005 .insert("bio".to_string(), Value::String(bio.to_string()));
1006 }
1007
1008 Ok(profile)
1009 }
1010
1011 async fn extract_google_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1013 let response = self
1014 .client
1015 .get("https://www.googleapis.com/oauth2/v2/userinfo")
1016 .bearer_auth(&token.access_token)
1017 .send()
1018 .await
1019 .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1020
1021 let json: Value = response
1022 .json()
1023 .await
1024 .map_err(|e| AuthError::ParseError(e.to_string()))?;
1025
1026 let mut profile = UserProfile::new();
1027 profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1028 profile = profile.with_provider("google".to_string());
1029
1030 if let Some(name) = json["name"].as_str() {
1031 profile.name = Some(name.to_string());
1032 }
1033
1034 if let Some(email) = json["email"].as_str() {
1035 profile.email = Some(email.to_string());
1036 }
1037
1038 if let Some(verified) = json["verified_email"].as_bool() {
1039 profile.email_verified = Some(verified);
1040 }
1041
1042 if let Some(picture) = json["picture"].as_str() {
1043 profile.picture = Some(picture.to_string());
1044 }
1045
1046 if let Some(locale) = json["locale"].as_str() {
1047 profile.locale = Some(locale.to_string());
1048 }
1049
1050 Ok(profile)
1051 }
1052
1053 async fn extract_microsoft_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1055 let response = self
1056 .client
1057 .get("https://graph.microsoft.com/v1.0/me")
1058 .bearer_auth(&token.access_token)
1059 .send()
1060 .await
1061 .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1062
1063 let json: Value = response
1064 .json()
1065 .await
1066 .map_err(|e| AuthError::ParseError(e.to_string()))?;
1067
1068 let mut profile = UserProfile::new();
1069 profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1070 profile = profile.with_provider("microsoft".to_string());
1071
1072 if let Some(display_name) = json["displayName"].as_str() {
1073 profile.name = Some(display_name.to_string());
1074 }
1075
1076 if let Some(user_principal_name) = json["userPrincipalName"].as_str() {
1077 profile.username = Some(user_principal_name.to_string());
1078 }
1079
1080 if let Some(mail) = json["mail"].as_str() {
1081 profile.email = Some(mail.to_string());
1082 }
1083
1084 if let Some(preferred_language) = json["preferredLanguage"].as_str() {
1085 profile.locale = Some(preferred_language.to_string());
1086 }
1087
1088 if let Some(job_title) = json["jobTitle"].as_str() {
1090 profile
1091 .additional_data
1092 .insert("jobTitle".to_string(), Value::String(job_title.to_string()));
1093 }
1094
1095 if let Some(office_location) = json["officeLocation"].as_str() {
1096 profile.additional_data.insert(
1097 "officeLocation".to_string(),
1098 Value::String(office_location.to_string()),
1099 );
1100 }
1101
1102 Ok(profile)
1103 }
1104
1105 async fn extract_discord_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1107 let response = self
1108 .client
1109 .get("https://discord.com/api/users/@me")
1110 .bearer_auth(&token.access_token)
1111 .send()
1112 .await
1113 .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1114
1115 let json: Value = response
1116 .json()
1117 .await
1118 .map_err(|e| AuthError::ParseError(e.to_string()))?;
1119
1120 let mut profile = UserProfile::new();
1121 profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1122 profile = profile.with_provider("discord".to_string());
1123
1124 if let Some(username) = json["username"].as_str() {
1125 profile.username = Some(username.to_string());
1126 }
1127
1128 if let Some(discriminator) = json["discriminator"].as_str() {
1129 profile.name = Some(format!(
1130 "{}#{}",
1131 json["username"].as_str().unwrap_or(""),
1132 discriminator
1133 ));
1134 }
1135
1136 if let Some(email) = json["email"].as_str() {
1137 profile.email = Some(email.to_string());
1138 }
1139
1140 if let Some(verified) = json["verified"].as_bool() {
1141 profile.email_verified = Some(verified);
1142 }
1143
1144 if let Some(avatar) = json["avatar"].as_str() {
1145 let user_id = json["id"].as_str().unwrap_or("");
1146 profile.picture = Some(format!(
1147 "https://cdn.discordapp.com/avatars/{}/{}.png",
1148 user_id, avatar
1149 ));
1150 }
1151
1152 if let Some(locale) = json["locale"].as_str() {
1153 profile.locale = Some(locale.to_string());
1154 }
1155
1156 Ok(profile)
1157 }
1158
1159 async fn extract_gitlab_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1161 let response = self
1162 .client
1163 .get("https://gitlab.com/api/v4/user")
1164 .bearer_auth(&token.access_token)
1165 .send()
1166 .await
1167 .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1168
1169 let json: Value = response
1170 .json()
1171 .await
1172 .map_err(|e| AuthError::ParseError(e.to_string()))?;
1173
1174 let mut profile = UserProfile::new();
1175 profile = profile.with_id(json["id"].as_u64().unwrap_or(0).to_string());
1176 profile = profile.with_provider("gitlab".to_string());
1177
1178 if let Some(username) = json["username"].as_str() {
1179 profile.username = Some(username.to_string());
1180 }
1181
1182 if let Some(name) = json["name"].as_str() {
1183 profile.name = Some(name.to_string());
1184 }
1185
1186 if let Some(email) = json["email"].as_str() {
1187 profile.email = Some(email.to_string());
1188 }
1189
1190 if let Some(avatar_url) = json["avatar_url"].as_str() {
1191 profile.picture = Some(avatar_url.to_string());
1192 }
1193
1194 if let Some(web_url) = json["web_url"].as_str() {
1196 profile
1197 .additional_data
1198 .insert("web_url".to_string(), Value::String(web_url.to_string()));
1199 }
1200
1201 if let Some(bio) = json["bio"].as_str() {
1202 profile
1203 .additional_data
1204 .insert("bio".to_string(), Value::String(bio.to_string()));
1205 }
1206
1207 Ok(profile)
1208 }
1209
1210 async fn extract_custom_profile(
1212 &self,
1213 token: &AuthToken,
1214 provider_name: &str,
1215 config: &OAuthProviderConfig,
1216 ) -> Result<UserProfile> {
1217 if let Some(user_info_url) = &config.userinfo_url {
1218 let response = self
1219 .client
1220 .get(user_info_url)
1221 .bearer_auth(&token.access_token)
1222 .send()
1223 .await
1224 .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1225
1226 let json: Value = response
1227 .json()
1228 .await
1229 .map_err(|e| AuthError::ParseError(e.to_string()))?;
1230
1231 let mut profile = UserProfile::new();
1232 profile = profile.with_id(
1233 json["id"]
1234 .as_str()
1235 .or_else(|| json["sub"].as_str())
1236 .unwrap_or("")
1237 .to_string(),
1238 );
1239 profile = profile.with_provider(provider_name.to_string());
1240
1241 if let Some(username) = json["username"].as_str().or_else(|| json["login"].as_str()) {
1243 profile.username = Some(username.to_string());
1244 }
1245
1246 if let Some(name) = json["name"]
1247 .as_str()
1248 .or_else(|| json["display_name"].as_str())
1249 {
1250 profile.name = Some(name.to_string());
1251 }
1252
1253 if let Some(email) = json["email"].as_str() {
1254 profile.email = Some(email.to_string());
1255 }
1256
1257 if let Some(verified) = json["email_verified"]
1258 .as_bool()
1259 .or_else(|| json["verified"].as_bool())
1260 {
1261 profile.email_verified = Some(verified);
1262 }
1263
1264 if let Some(picture) = json["picture"]
1265 .as_str()
1266 .or_else(|| json["avatar_url"].as_str())
1267 {
1268 profile.picture = Some(picture.to_string());
1269 }
1270
1271 if let Some(locale) = json["locale"].as_str().or_else(|| json["lang"].as_str()) {
1272 profile.locale = Some(locale.to_string());
1273 }
1274
1275 for (key, value) in json.as_object().unwrap_or(&serde_json::Map::new()) {
1277 if ![
1278 "id",
1279 "sub",
1280 "username",
1281 "login",
1282 "name",
1283 "display_name",
1284 "email",
1285 "email_verified",
1286 "verified",
1287 "picture",
1288 "avatar_url",
1289 "locale",
1290 "lang",
1291 ]
1292 .contains(&key.as_str())
1293 {
1294 profile.additional_data.insert(key.clone(), value.clone());
1295 }
1296 }
1297
1298 Ok(profile)
1299 } else {
1300 Err(AuthError::ConfigurationError(
1301 "Custom provider requires user_info_url".to_string(),
1302 ))
1303 }
1304 }
1305}
1306
1307impl Default for ProfileExtractor {
1308 fn default() -> Self {
1309 Self::new()
1310 }
1311}
1312
1313impl fmt::Display for OAuthProvider {
1314 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1315 match self {
1316 OAuthProvider::GitHub => write!(f, "github"),
1317 OAuthProvider::Google => write!(f, "google"),
1318 OAuthProvider::Microsoft => write!(f, "microsoft"),
1319 OAuthProvider::Discord => write!(f, "discord"),
1320 OAuthProvider::Twitter => write!(f, "twitter"),
1321 OAuthProvider::Facebook => write!(f, "facebook"),
1322 OAuthProvider::LinkedIn => write!(f, "linkedin"),
1323 OAuthProvider::GitLab => write!(f, "gitlab"),
1324 OAuthProvider::Custom { name, .. } => write!(f, "{}", name),
1325 }
1326 }
1327}
1328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config() {
        let github = OAuthProvider::GitHub;
        let config = github.config();

        assert_eq!(
            config.authorization_url,
            "https://github.com/login/oauth/authorize"
        );
        assert_eq!(
            config.token_url,
            "https://github.com/login/oauth/access_token"
        );
        assert!(config.supports_pkce);
    }

    #[test]
    fn test_authorization_url() {
        let github = OAuthProvider::GitHub;
        let url = github
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                None,
                Some("challenge123"),
            )
            .unwrap();

        assert!(url.contains("client_id=client123"));
        assert!(url.contains("redirect_uri=https%3A%2F%2Fexample.com%2Fcallback"));
        assert!(url.contains("state=state123"));
        assert!(url.contains("code_challenge=challenge123"));
    }

    /// `Display` yields the lowercase identifier for built-in providers and
    /// the configured name for custom ones.
    #[test]
    fn test_provider_display() {
        let builtin = [
            (OAuthProvider::GitHub, "github"),
            (OAuthProvider::Google, "google"),
            (OAuthProvider::Microsoft, "microsoft"),
            (OAuthProvider::Discord, "discord"),
            (OAuthProvider::Twitter, "twitter"),
            (OAuthProvider::Facebook, "facebook"),
            (OAuthProvider::LinkedIn, "linkedin"),
            (OAuthProvider::GitLab, "gitlab"),
        ];
        for (provider, expected) in builtin {
            assert_eq!(provider.to_string(), expected);
        }

        // Only the name matters for `Display`, so any config will do; reuse a
        // built-in one rather than spelling out every field.
        let custom = OAuthProvider::Custom {
            name: "acme-sso".to_string(),
            config: Box::new(OAuthProvider::GitLab.config().clone()),
        };
        assert_eq!(custom.to_string(), "acme-sso");
    }

    #[test]
    fn test_generate_state() {
        let state1 = generate_state();
        let state2 = generate_state();

        assert_eq!(state1.len(), 43);
        assert_eq!(state2.len(), 43);
        assert_ne!(state1, state2);
    }

    #[test]
    fn test_generate_pkce() {
        let (verifier1, challenge1) = generate_pkce();
        let (verifier2, challenge2) = generate_pkce();

        assert_eq!(verifier1.len(), 128);
        assert_eq!(verifier2.len(), 128);
        assert_ne!(verifier1, verifier2);
        assert_ne!(challenge1, challenge2);
    }
}
1389
1390