1use base64::Engine;
4use crate::errors::{AuthError, Result};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use url::Url;
8
/// An OAuth 2.0 identity provider.
///
/// The built-in variants map to hard-coded endpoint configurations returned by
/// [`OAuthProvider::config`]; `Custom` carries a caller-supplied name and
/// configuration (see [`OAuthProvider::custom`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OAuthProvider {
    /// GitHub OAuth apps (`github.com`).
    GitHub,

    /// Google accounts (`accounts.google.com` / `oauth2.googleapis.com`).
    Google,

    /// Microsoft identity platform, using the `common` tenant endpoints.
    Microsoft,

    /// Discord (`discord.com`).
    Discord,

    /// Twitter / X OAuth 2.0 (`twitter.com`, `api.twitter.com`).
    Twitter,

    /// Facebook Login (Graph API v18.0 endpoints).
    Facebook,

    /// LinkedIn (`linkedin.com`).
    LinkedIn,

    /// GitLab.com (`gitlab.com`).
    GitLab,

    /// A provider that is not built into this module.
    Custom {
        /// Identifier returned by [`OAuthProvider::name`].
        name: String,
        /// Endpoint and capability configuration returned by [`OAuthProvider::config`].
        /// Boxed — presumably to keep the enum's size small; confirm if it matters.
        config: Box<OAuthProviderConfig>,
    },
}
42
/// Endpoints and capability flags that describe an OAuth 2.0 provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthProviderConfig {
    /// Authorization endpoint the user is sent to; parsed and extended with
    /// query parameters by `build_authorization_url`.
    pub authorization_url: String,

    /// Token endpoint used for code exchange, refresh, and device-code polling.
    pub token_url: String,

    /// Device authorization endpoint (device flow), if the provider has one.
    pub device_authorization_url: Option<String>,

    /// Endpoint that returns the authenticated user's profile, if any.
    pub userinfo_url: Option<String>,

    /// Token revocation endpoint, if any.
    pub revocation_url: Option<String>,

    /// Scopes requested when the caller does not supply its own.
    pub default_scopes: Vec<String>,

    /// When true, a supplied PKCE `code_challenge` is added to authorization URLs.
    pub supports_pkce: bool,

    /// When false, refresh-token requests are rejected before any network call.
    pub supports_refresh: bool,

    /// When false, device-flow requests are rejected before any network call.
    pub supports_device_flow: bool,

    /// Extra query parameters appended to every authorization URL.
    pub additional_params: HashMap<String, String>,
}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct DeviceAuthorizationResponse {
80 pub device_code: String,
82
83 pub user_code: String,
85
86 pub verification_uri: String,
88
89 pub verification_uri_complete: Option<String>,
91
92 pub interval: u64,
94
95 pub expires_in: u64,
97}
98
/// Provider-agnostic user profile, typically assembled with [`UserProfile::new`]
/// and the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserProfile {
    /// Provider-specific user identifier.
    pub id: String,

    /// Identifier of the provider the profile came from (presumably the value of
    /// `OAuthProvider::name`; confirm against callers).
    pub provider: String,

    /// Login / handle, when known.
    pub username: Option<String>,

    /// Display name, when known.
    pub name: Option<String>,

    /// E-mail address, when known.
    pub email: Option<String>,

    /// Whether the provider reported the e-mail as verified; `None` when unknown.
    pub email_verified: Option<bool>,

    /// Avatar / picture URL, when known.
    pub picture: Option<String>,

    /// Locale string, when known.
    pub locale: Option<String>,

    /// Arbitrary extra attributes keyed by name.
    pub additional_data: HashMap<String, serde_json::Value>,
}
129
130impl UserProfile {
131 pub fn new(id: impl Into<String>, provider: impl Into<String>) -> Self {
133 Self {
134 id: id.into(),
135 provider: provider.into(),
136 username: None,
137 name: None,
138 email: None,
139 email_verified: None,
140 picture: None,
141 locale: None,
142 additional_data: HashMap::new(),
143 }
144 }
145
146 pub fn with_username(mut self, username: impl Into<String>) -> Self {
148 self.username = Some(username.into());
149 self
150 }
151
152 pub fn with_name(mut self, name: impl Into<String>) -> Self {
154 self.name = Some(name.into());
155 self
156 }
157
158 pub fn with_email(mut self, email: impl Into<String>) -> Self {
160 self.email = Some(email.into());
161 self
162 }
163
164 pub fn with_email_verified(mut self, verified: bool) -> Self {
166 self.email_verified = Some(verified);
167 self
168 }
169
170 pub fn with_picture(mut self, picture: impl Into<String>) -> Self {
172 self.picture = Some(picture.into());
173 self
174 }
175
176 pub fn with_locale(mut self, locale: impl Into<String>) -> Self {
178 self.locale = Some(locale.into());
179 self
180 }
181
182 pub fn with_additional_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
184 self.additional_data.insert(key.into(), value);
185 self
186 }
187}
188
/// Successful response body from a provider's token endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthTokenResponse {
    /// The issued access token.
    pub access_token: String,

    /// Token type reported by the provider (RFC 6749 `token_type`).
    pub token_type: String,

    /// Access-token lifetime as reported by the provider (seconds per RFC 6749),
    /// when present.
    pub expires_in: Option<u64>,

    /// Refresh token, when the provider issued one.
    pub refresh_token: Option<String>,

    /// Granted scopes, verbatim as returned by the provider.
    pub scope: Option<String>,

    /// Every other member of the response body, captured via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
211
/// User information normalized from a provider's user-info endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUserInfo {
    /// Provider-specific user identifier, always rendered as a string.
    pub id: String,

    /// Login / handle, when the provider reports one.
    pub username: Option<String>,

    /// Display name, when reported.
    pub name: Option<String>,

    /// E-mail address, when reported.
    pub email: Option<String>,

    /// Whether the e-mail is verified; `None` when the provider does not say.
    pub email_verified: Option<bool>,

    /// Avatar / picture URL, when reported.
    pub picture: Option<String>,

    /// Locale string, when reported.
    pub locale: Option<String>,

    /// Extra members of the provider's payload, captured via `#[serde(flatten)]`
    /// when deserializing.
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
240
241impl OAuthProvider {
242 pub fn config(&self) -> OAuthProviderConfig {
244 match self {
245 Self::GitHub => OAuthProviderConfig {
246 authorization_url: "https://github.com/login/oauth/authorize".to_string(),
247 token_url: "https://github.com/login/oauth/access_token".to_string(),
248 device_authorization_url: Some("https://github.com/login/device/code".to_string()),
249 userinfo_url: Some("https://api.github.com/user".to_string()),
250 revocation_url: None,
251 default_scopes: vec!["user:email".to_string()],
252 supports_pkce: true,
253 supports_refresh: false,
254 supports_device_flow: true,
255 additional_params: HashMap::new(),
256 },
257
258 Self::Google => OAuthProviderConfig {
259 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
260 token_url: "https://oauth2.googleapis.com/token".to_string(),
261 device_authorization_url: Some("https://oauth2.googleapis.com/device/code".to_string()),
262 userinfo_url: Some("https://www.googleapis.com/oauth2/v2/userinfo".to_string()),
263 revocation_url: Some("https://oauth2.googleapis.com/revoke".to_string()),
264 default_scopes: vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
265 supports_pkce: true,
266 supports_refresh: true,
267 supports_device_flow: true,
268 additional_params: HashMap::new(),
269 },
270
271 Self::Microsoft => OAuthProviderConfig {
272 authorization_url: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize".to_string(),
273 token_url: "https://login.microsoftonline.com/common/oauth2/v2.0/token".to_string(),
274 device_authorization_url: Some("https://login.microsoftonline.com/common/oauth2/v2.0/devicecode".to_string()),
275 userinfo_url: Some("https://graph.microsoft.com/v1.0/me".to_string()),
276 revocation_url: None,
277 default_scopes: vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
278 supports_pkce: true,
279 supports_refresh: true,
280 supports_device_flow: true,
281 additional_params: HashMap::new(),
282 },
283
284 Self::Discord => OAuthProviderConfig {
285 authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
286 token_url: "https://discord.com/api/oauth2/token".to_string(),
287 device_authorization_url: None,
288 userinfo_url: Some("https://discord.com/api/users/@me".to_string()),
289 revocation_url: Some("https://discord.com/api/oauth2/token/revoke".to_string()),
290 default_scopes: vec!["identify".to_string(), "email".to_string()],
291 supports_pkce: false,
292 supports_refresh: true,
293 supports_device_flow: false,
294 additional_params: HashMap::new(),
295 },
296
297 Self::Twitter => OAuthProviderConfig {
298 authorization_url: "https://twitter.com/i/oauth2/authorize".to_string(),
299 token_url: "https://api.twitter.com/2/oauth2/token".to_string(),
300 device_authorization_url: None,
301 userinfo_url: Some("https://api.twitter.com/2/users/me".to_string()),
302 revocation_url: Some("https://api.twitter.com/2/oauth2/revoke".to_string()),
303 default_scopes: vec!["tweet.read".to_string(), "users.read".to_string()],
304 supports_pkce: true,
305 supports_refresh: true,
306 supports_device_flow: false,
307 additional_params: HashMap::new(),
308 },
309
310 Self::Facebook => OAuthProviderConfig {
311 authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
312 token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
313 device_authorization_url: None,
314 userinfo_url: Some("https://graph.facebook.com/me".to_string()),
315 revocation_url: None,
316 default_scopes: vec!["email".to_string(), "public_profile".to_string()],
317 supports_pkce: false,
318 supports_refresh: false,
319 supports_device_flow: false,
320 additional_params: HashMap::new(),
321 },
322
323 Self::LinkedIn => OAuthProviderConfig {
324 authorization_url: "https://www.linkedin.com/oauth/v2/authorization".to_string(),
325 token_url: "https://www.linkedin.com/oauth/v2/accessToken".to_string(),
326 device_authorization_url: None,
327 userinfo_url: Some("https://api.linkedin.com/v2/me".to_string()),
328 revocation_url: None,
329 default_scopes: vec!["r_liteprofile".to_string(), "r_emailaddress".to_string()],
330 supports_pkce: false,
331 supports_refresh: true,
332 supports_device_flow: false,
333 additional_params: HashMap::new(),
334 },
335
336 Self::GitLab => OAuthProviderConfig {
337 authorization_url: "https://gitlab.com/oauth/authorize".to_string(),
338 token_url: "https://gitlab.com/oauth/token".to_string(),
339 device_authorization_url: None,
340 userinfo_url: Some("https://gitlab.com/api/v4/user".to_string()),
341 revocation_url: Some("https://gitlab.com/oauth/revoke".to_string()),
342 default_scopes: vec!["read_user".to_string()],
343 supports_pkce: true,
344 supports_refresh: true,
345 supports_device_flow: false,
346 additional_params: HashMap::new(),
347 },
348
349 Self::Custom { config, .. } => *config.clone(),
350 }
351 }
352
353 pub fn name(&self) -> &str {
355 match self {
356 Self::GitHub => "github",
357 Self::Google => "google",
358 Self::Microsoft => "microsoft",
359 Self::Discord => "discord",
360 Self::Twitter => "twitter",
361 Self::Facebook => "facebook",
362 Self::LinkedIn => "linkedin",
363 Self::GitLab => "gitlab",
364 Self::Custom { name, .. } => name,
365 }
366 }
367
368 pub fn custom(name: impl Into<String>, config: OAuthProviderConfig) -> Self {
370 Self::Custom {
371 name: name.into(),
372 config: Box::new(config),
373 }
374 }
375
376 pub fn build_authorization_url(
378 &self,
379 client_id: &str,
380 redirect_uri: &str,
381 state: &str,
382 scopes: Option<&[String]>,
383 code_challenge: Option<&str>,
384 ) -> Result<String> {
385 let config = self.config();
386 let mut url = Url::parse(&config.authorization_url)
387 .map_err(|e| AuthError::config(format!("Invalid authorization URL: {e}")))?;
388
389 let scopes = scopes.unwrap_or(&config.default_scopes);
390
391 {
392 let mut query = url.query_pairs_mut();
393 query.append_pair("client_id", client_id);
394 query.append_pair("redirect_uri", redirect_uri);
395 query.append_pair("response_type", "code");
396 query.append_pair("state", state);
397
398 if !scopes.is_empty() {
399 query.append_pair("scope", &scopes.join(" "));
400 }
401
402 if config.supports_pkce {
404 if let Some(challenge) = code_challenge {
405 query.append_pair("code_challenge", challenge);
406 query.append_pair("code_challenge_method", "S256");
407 }
408 }
409
410 for (key, value) in &config.additional_params {
412 query.append_pair(key, value);
413 }
414 }
415
416 Ok(url.to_string())
417 }
418
419 pub async fn exchange_code(
421 &self,
422 client_id: &str,
423 client_secret: &str,
424 authorization_code: &str,
425 redirect_uri: &str,
426 code_verifier: Option<&str>,
427 ) -> Result<OAuthTokenResponse> {
428 let config = self.config();
429 let client = reqwest::Client::new();
430
431 let mut params = vec![
432 ("grant_type", "authorization_code"),
433 ("client_id", client_id),
434 ("client_secret", client_secret),
435 ("code", authorization_code),
436 ("redirect_uri", redirect_uri),
437 ];
438
439 if let Some(verifier) = code_verifier {
441 params.push(("code_verifier", verifier));
442 }
443
444 let response = client
445 .post(&config.token_url)
446 .form(¶ms)
447 .send()
448 .await?;
449
450 if !response.status().is_success() {
451 let error_text = response.text().await.unwrap_or_default();
452 return Err(AuthError::auth_method(
453 self.name(),
454 format!("Token exchange failed: {error_text}"),
455 ));
456 }
457
458 let token_response: OAuthTokenResponse = response.json().await?;
459 Ok(token_response)
460 }
461
462 pub async fn refresh_token(
464 &self,
465 client_id: &str,
466 client_secret: &str,
467 refresh_token: &str,
468 ) -> Result<OAuthTokenResponse> {
469 let config = self.config();
470
471 if !config.supports_refresh {
472 return Err(AuthError::auth_method(
473 self.name(),
474 "Provider does not support token refresh".to_string(),
475 ));
476 }
477
478 let client = reqwest::Client::new();
479
480 let params = vec![
481 ("grant_type", "refresh_token"),
482 ("client_id", client_id),
483 ("client_secret", client_secret),
484 ("refresh_token", refresh_token),
485 ];
486
487 let response = client
488 .post(&config.token_url)
489 .form(¶ms)
490 .send()
491 .await?;
492
493 if !response.status().is_success() {
494 let error_text = response.text().await.unwrap_or_default();
495 return Err(AuthError::auth_method(
496 self.name(),
497 format!("Token refresh failed: {error_text}"),
498 ));
499 }
500
501 let token_response: OAuthTokenResponse = response.json().await?;
502 Ok(token_response)
503 }
504
505 pub async fn get_user_info(&self, access_token: &str) -> Result<OAuthUserInfo> {
507 let config = self.config();
508
509 let userinfo_url = config.userinfo_url.ok_or_else(|| {
510 AuthError::auth_method(
511 self.name(),
512 "Provider does not support user info endpoint".to_string(),
513 )
514 })?;
515
516 let client = reqwest::Client::new();
517 let response = client
518 .get(&userinfo_url)
519 .bearer_auth(access_token)
520 .send()
521 .await?;
522
523 if !response.status().is_success() {
524 let error_text = response.text().await.unwrap_or_default();
525 return Err(AuthError::auth_method(
526 self.name(),
527 format!("User info request failed: {error_text}"),
528 ));
529 }
530
531 let user_data: serde_json::Value = response.json().await?;
532
533 let user_info = self.parse_user_info(user_data)?;
535 Ok(user_info)
536 }
537
538 fn parse_user_info(&self, data: serde_json::Value) -> Result<OAuthUserInfo> {
540 let mut additional_fields = HashMap::new();
541
542 let user_info = match self {
543 Self::GitHub => {
544 let id = data["id"].as_u64()
545 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
546 .to_string();
547
548 OAuthUserInfo {
549 id,
550 username: data["login"].as_str().map(|s| s.to_string()),
551 email: data["email"].as_str().map(|s| s.to_string()),
552 name: data["name"].as_str().map(|s| s.to_string()),
553 picture: data["avatar_url"].as_str().map(|s| s.to_string()),
554 email_verified: None, locale: None,
556 additional_fields,
557 }
558 }
559
560 Self::Google => {
561 let id = data["id"].as_str()
562 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
563 .to_string();
564
565 OAuthUserInfo {
566 id,
567 username: None, email: data["email"].as_str().map(|s| s.to_string()),
569 name: data["name"].as_str().map(|s| s.to_string()),
570 picture: data["picture"].as_str().map(|s| s.to_string()),
571 email_verified: data["verified_email"].as_bool(),
572 locale: data["locale"].as_str().map(|s| s.to_string()),
573 additional_fields,
574 }
575 }
576
577 _ => {
579 let id = data["id"].as_str()
581 .or_else(|| data["sub"].as_str())
582 .or_else(|| data["user_id"].as_str())
583 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
584 .to_string();
585
586 if let serde_json::Value::Object(map) = data {
588 additional_fields = map.into_iter().collect();
589 }
590
591 OAuthUserInfo {
592 id,
593 username: additional_fields.get("username")
594 .or_else(|| additional_fields.get("login"))
595 .and_then(|v| v.as_str())
596 .map(|s| s.to_string()),
597 email: additional_fields.get("email")
598 .and_then(|v| v.as_str())
599 .map(|s| s.to_string()),
600 name: additional_fields.get("name")
601 .or_else(|| additional_fields.get("display_name"))
602 .and_then(|v| v.as_str())
603 .map(|s| s.to_string()),
604 picture: additional_fields.get("avatar_url")
605 .or_else(|| additional_fields.get("picture"))
606 .and_then(|v| v.as_str())
607 .map(|s| s.to_string()),
608 email_verified: additional_fields.get("email_verified")
609 .and_then(|v| v.as_bool()),
610 locale: additional_fields.get("locale")
611 .and_then(|v| v.as_str())
612 .map(|s| s.to_string()),
613 additional_fields,
614 }
615 }
616 };
617
618 Ok(user_info)
619 }
620
621 pub async fn revoke_token(&self, access_token: &str) -> Result<()> {
623 let config = self.config();
624
625 let revocation_url = config.revocation_url.ok_or_else(|| {
626 AuthError::auth_method(
627 self.name(),
628 "Provider does not support token revocation".to_string(),
629 )
630 })?;
631
632 let client = reqwest::Client::new();
633 let response = client
634 .post(&revocation_url)
635 .form(&[("token", access_token)])
636 .send()
637 .await?;
638
639 if !response.status().is_success() {
640 let error_text = response.text().await.unwrap_or_default();
641 return Err(AuthError::auth_method(
642 self.name(),
643 format!("Token revocation failed: {error_text}"),
644 ));
645 }
646
647 Ok(())
648 }
649
650 pub async fn device_authorization(
652 &self,
653 client_id: &str,
654 scope: Option<&[String]>,
655 ) -> Result<DeviceAuthorizationResponse> {
656 let config = self.config();
657
658 if !config.supports_device_flow {
659 return Err(AuthError::auth_method(
660 self.name(),
661 "Provider does not support device authorization flow".to_string(),
662 ));
663 }
664
665 let client = reqwest::Client::new();
666
667 let scope_string = scope.unwrap_or(&config.default_scopes).join(" ");
668 let params = vec![
669 ("client_id", client_id),
670 ("scope", scope_string.as_str()),
671 ];
672
673 let response = client
674 .post(config.device_authorization_url.as_deref().unwrap())
675 .form(¶ms)
676 .send()
677 .await?;
678
679 if !response.status().is_success() {
680 let error_text = response.text().await.unwrap_or_default();
681 return Err(AuthError::auth_method(
682 self.name(),
683 format!("Device authorization request failed: {error_text}"),
684 ));
685 }
686
687 let device_response: DeviceAuthorizationResponse = response.json().await?;
688 Ok(device_response)
689 }
690
691 pub async fn poll_device_code(
693 &self,
694 client_id: &str,
695 device_code: &str,
696 _interval: Option<u64>,
697 ) -> Result<OAuthTokenResponse> {
698 let config = self.config();
699
700 if !config.supports_device_flow {
701 return Err(AuthError::auth_method(
702 self.name(),
703 "Provider does not support device authorization flow".to_string(),
704 ));
705 }
706
707 let client = reqwest::Client::new();
708
709 let params = vec![
710 ("client_id", client_id),
711 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
712 ("device_code", device_code),
713 ];
714
715 let response = client
716 .post(&config.token_url)
717 .form(¶ms)
718 .send()
719 .await?;
720
721 if !response.status().is_success() {
722 let error_text = response.text().await.unwrap_or_default();
723 return Err(AuthError::auth_method(
724 self.name(),
725 format!("Token request failed: {error_text}"),
726 ));
727 }
728
729 let token_response: OAuthTokenResponse = response.json().await?;
730 Ok(token_response)
731 }
732}
733
734pub fn generate_state() -> String {
736 use rand::Rng;
737 let mut rng = rand::thread_rng();
738 (0..32)
739 .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
740 .collect()
741}
742
743pub fn generate_pkce() -> (String, String) {
745 use rand::Rng;
746 use ring::digest;
747
748 let mut rng = rand::thread_rng();
750 let code_verifier: String = (0..128)
751 .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
752 .collect();
753
754 let digest = digest::digest(&digest::SHA256, code_verifier.as_bytes()); let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest.as_ref());
756
757 (code_verifier, code_challenge)
758}
759
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config() {
        let config = OAuthProvider::GitHub.config();

        assert_eq!(config.authorization_url, "https://github.com/login/oauth/authorize");
        assert_eq!(config.token_url, "https://github.com/login/oauth/access_token");
        assert!(config.supports_pkce);
        assert!(config.supports_device_flow);
        assert!(config.device_authorization_url.is_some());
    }

    #[test]
    fn test_custom_provider_round_trips_config() {
        let custom_config = OAuthProvider::GitLab.config();
        let provider = OAuthProvider::custom("acme", custom_config.clone());

        assert_eq!(provider.name(), "acme");
        assert_eq!(provider.config(), custom_config);
    }

    #[test]
    fn test_authorization_url() {
        let url = OAuthProvider::GitHub
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                None,
                Some("challenge123"),
            )
            .unwrap();

        assert!(url.contains("client_id=client123"));
        assert!(url.contains("redirect_uri=https%3A%2F%2Fexample.com%2Fcallback"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("state=state123"));
        // Default scopes are used when none are supplied.
        assert!(url.contains("scope=user%3Aemail"));
        assert!(url.contains("code_challenge=challenge123"));
        assert!(url.contains("code_challenge_method=S256"));
    }

    #[test]
    fn test_authorization_url_with_explicit_scopes() {
        let scopes = vec!["repo".to_string(), "read:org".to_string()];
        let url = OAuthProvider::GitHub
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                Some(scopes.as_slice()),
                None,
            )
            .unwrap();

        assert!(url.contains("scope=repo+read%3Aorg"));
        assert!(!url.contains("code_challenge"));
    }

    #[test]
    fn test_authorization_url_omits_pkce_when_unsupported() {
        let url = OAuthProvider::Discord
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                None,
                Some("challenge123"),
            )
            .unwrap();

        assert!(url.contains("scope=identify+email"));
        assert!(!url.contains("code_challenge"));
    }

    #[test]
    fn test_parse_github_user_info() {
        let data = serde_json::json!({
            "id": 583231,
            "login": "octocat",
            "name": "The Octocat",
            "email": null,
            "avatar_url": "https://avatars.githubusercontent.com/u/583231"
        });

        let info = OAuthProvider::GitHub.parse_user_info(data).unwrap();

        assert_eq!(info.id, "583231");
        assert_eq!(info.username.as_deref(), Some("octocat"));
        assert_eq!(info.name.as_deref(), Some("The Octocat"));
        assert_eq!(info.email, None);
        assert_eq!(
            info.picture.as_deref(),
            Some("https://avatars.githubusercontent.com/u/583231")
        );
    }

    #[test]
    fn test_parse_user_info_requires_id() {
        let data = serde_json::json!({ "login": "octocat" });
        assert!(OAuthProvider::GitHub.parse_user_info(data).is_err());
    }

    #[test]
    fn test_generate_state() {
        let state1 = generate_state();
        let state2 = generate_state();

        assert_eq!(state1.len(), 32);
        assert_eq!(state2.len(), 32);
        assert!(state1.chars().all(|c| c.is_ascii_alphanumeric()));
        assert_ne!(state1, state2);
    }

    #[test]
    fn test_generate_pkce() {
        use base64::Engine as _;

        let (verifier1, challenge1) = generate_pkce();
        let (verifier2, challenge2) = generate_pkce();

        assert_eq!(verifier1.len(), 128);
        assert_eq!(verifier2.len(), 128);
        assert!(verifier1.chars().all(|c| c.is_ascii_alphanumeric()));
        assert_ne!(verifier1, verifier2);
        assert_ne!(challenge1, challenge2);

        // S256: challenge = BASE64URL-NOPAD(SHA256(verifier)) (RFC 7636 §4.2).
        let digest = ring::digest::digest(&ring::digest::SHA256, verifier1.as_bytes());
        let expected = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest.as_ref());
        assert_eq!(challenge1, expected);
        assert_eq!(challenge1.len(), 43);
    }
}