auth_framework/
providers.rs

1//! OAuth provider configurations and implementations.
2impl Default for UserProfile {
3    fn default() -> Self {
4        Self::new()
5    }
6}
7use crate::errors::{AuthError, Result};
8use crate::tokens::AuthToken;
9use base64::Engine;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::fmt;
15use url::Url;
16
17/// Supported OAuth providers.
18#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
19pub enum OAuthProvider {
20    /// GitHub OAuth provider
21    GitHub,
22
23    /// Google OAuth provider
24    Google,
25
26    /// Microsoft OAuth provider
27    Microsoft,
28
29    /// Discord OAuth provider
30    Discord,
31
32    /// Twitter OAuth provider
33    Twitter,
34
35    /// Facebook OAuth provider
36    Facebook,
37
38    /// LinkedIn OAuth provider
39    LinkedIn,
40
41    /// GitLab OAuth provider
42    GitLab,
43
44    /// Generic OAuth provider with custom configuration
45    Custom {
46        name: String,
47        config: Box<OAuthProviderConfig>,
48    },
49}
50
51/// OAuth provider configuration.
52#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
53pub struct OAuthProviderConfig {
54    /// Authorization endpoint URL
55    pub authorization_url: String,
56
57    /// Token endpoint URL
58    pub token_url: String,
59
60    /// Device authorization endpoint URL (for device flow)
61    pub device_authorization_url: Option<String>,
62
63    /// User info endpoint URL
64    pub userinfo_url: Option<String>,
65
66    /// Revocation endpoint URL
67    pub revocation_url: Option<String>,
68
69    /// Default scopes to request
70    pub default_scopes: Vec<String>,
71
72    /// Whether this provider supports PKCE
73    pub supports_pkce: bool,
74
75    /// Whether this provider supports refresh tokens
76    pub supports_refresh: bool,
77
78    /// Whether this provider supports device flow
79    pub supports_device_flow: bool,
80
81    /// Custom parameters to include in authorization requests
82    pub additional_params: HashMap<String, String>,
83}
84
85/// Device flow authorization response.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct DeviceAuthorizationResponse {
88    /// Device code
89    pub device_code: String,
90
91    /// User code that the user should enter
92    pub user_code: String,
93
94    /// URL where the user should verify the device
95    pub verification_uri: String,
96
97    /// Complete verification URL (optional)
98    pub verification_uri_complete: Option<String>,
99
100    /// Interval in seconds between polling requests
101    pub interval: u64,
102
103    /// Device code expires in seconds
104    pub expires_in: u64,
105}
106
107/// Standardized user profile across all providers.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct UserProfile {
110    /// Unique identifier from the provider
111    pub id: Option<String>,
112
113    /// Provider that authenticated this user
114    pub provider: Option<String>,
115
116    /// Username or login name
117    pub username: Option<String>,
118
119    /// Display name
120    pub name: Option<String>,
121
122    /// Email address
123    pub email: Option<String>,
124
125    /// Whether email is verified
126    pub email_verified: Option<bool>,
127
128    /// Profile picture URL
129    pub picture: Option<String>,
130
131    /// Locale/language preference
132    pub locale: Option<String>,
133
134    /// Provider-specific additional data
135    pub additional_data: HashMap<String, serde_json::Value>,
136}
137
138#[cfg(feature = "postgres-storage")]
139use sqlx::{Decode, Postgres, Type, postgres::PgValueRef};
140
#[cfg(feature = "postgres-storage")]
impl<'r> Decode<'r, Postgres> for UserProfile {
    /// Decode a profile persisted as a JSON document.
    ///
    /// The column is first read as a raw `serde_json::Value` and then
    /// deserialized into a `UserProfile`. Either failure is returned as a
    /// boxed decode error.
    fn decode(value: PgValueRef<'r>) -> std::result::Result<Self, sqlx::error::BoxDynError> {
        let raw = <serde_json::Value as Decode<Postgres>>::decode(value)?;
        let profile = serde_json::from_value::<Self>(raw)?;
        Ok(profile)
    }
}
148
#[cfg(feature = "postgres-storage")]
impl Type<Postgres> for UserProfile {
    /// Profiles are stored as JSON, so report `serde_json::Value`'s Postgres type.
    fn type_info() -> sqlx::postgres::PgTypeInfo {
        <serde_json::Value as sqlx::Type<Postgres>>::type_info()
    }

    /// Accept exactly the column types that `serde_json::Value` accepts.
    fn compatible(candidate: &sqlx::postgres::PgTypeInfo) -> bool {
        <serde_json::Value as sqlx::Type<Postgres>>::compatible(candidate)
    }
}
158
159impl UserProfile {
160    /// Create a new empty user profile
161    pub fn new() -> Self {
162        Self {
163            id: None,
164            provider: None,
165            username: None,
166            name: None,
167            email: None,
168            email_verified: None,
169            picture: None,
170            locale: None,
171            additional_data: HashMap::new(),
172        }
173    }
174
175    /// Set user ID
176    pub fn with_id(mut self, id: impl Into<String>) -> Self {
177        self.id = Some(id.into());
178        self
179    }
180
181    /// Set provider
182    pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
183        self.provider = Some(provider.into());
184        self
185    }
186
187    /// Set username
188    pub fn with_username(mut self, username: Option<impl Into<String>>) -> Self {
189        self.username = username.map(Into::into);
190        self
191    }
192
193    /// Set display name
194    pub fn with_name(mut self, name: Option<impl Into<String>>) -> Self {
195        self.name = name.map(Into::into);
196        self
197    }
198
199    /// Set email
200    pub fn with_email(mut self, email: Option<impl Into<String>>) -> Self {
201        self.email = email.map(Into::into);
202        self
203    }
204
205    /// Set email verification status
206    pub fn with_email_verified(mut self, verified: bool) -> Self {
207        self.email_verified = Some(verified);
208        self
209    }
210
211    /// Set profile picture URL
212    pub fn with_picture(mut self, picture: Option<impl Into<String>>) -> Self {
213        self.picture = picture.map(Into::into);
214        self
215    }
216
217    /// Set locale
218    pub fn with_locale(mut self, locale: Option<impl Into<String>>) -> Self {
219        self.locale = locale.map(Into::into);
220        self
221    }
222
223    /// Add additional provider-specific data
224    pub fn with_additional_data(
225        mut self,
226        key: impl Into<String>,
227        value: serde_json::Value,
228    ) -> Self {
229        self.additional_data.insert(key.into(), value);
230        self
231    }
232
233    /// Create a new user profile from an OAuth token response
234    pub fn from_token_response(
235        token: &OAuthTokenResponse,
236        provider: &OAuthProvider,
237    ) -> Option<Self> {
238        // Extract user info from ID token if present in additional fields
239        if let Some(id_token_value) = token.additional_fields.get("id_token")
240            && let Some(id_token) = id_token_value.as_str()
241            && let Ok(profile) = Self::from_id_token(id_token)
242        {
243            return Some(profile.with_provider(provider.to_string()));
244        }
245        None
246    }
247
248    /// Extract a user profile from an ID token (JWT)
249    pub fn from_id_token(id_token: &str) -> Result<Self> {
250        // Basic JWT parsing
251        let parts: Vec<&str> = id_token.split('.').collect();
252        if parts.len() != 3 {
253            return Err(AuthError::validation("Invalid JWT format"));
254        }
255
256        // Decode the payload (second part)
257        let payload = parts[1];
258        let padding_len = payload.len() % 4;
259        let padded_payload = if padding_len > 0 {
260            format!("{}{}", payload, "=".repeat(4 - padding_len))
261        } else {
262            payload.to_string()
263        };
264
265        // Decode base64
266        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
267            .decode(&padded_payload)
268            .map_err(|e| AuthError::validation(format!("Failed to decode JWT: {}", e)))?;
269
270        // Parse JSON
271        let json: Value = serde_json::from_slice(&decoded)
272            .map_err(|e| AuthError::validation(format!("Failed to parse JWT payload: {}", e)))?;
273
274        // Extract common claims
275        let mut profile = Self::new();
276
277        // Try common ID fields
278        if let Some(sub) = json.get("sub").and_then(|v| v.as_str()) {
279            profile = profile.with_id(sub);
280        } else if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
281            profile = profile.with_id(id);
282        } else {
283            return Err(AuthError::validation("JWT missing subject claim"));
284        }
285
286        // Extract other common fields
287        if let Some(name) = json.get("name").and_then(|v| v.as_str()) {
288            profile = profile.with_name(Some(name));
289        }
290
291        if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
292            profile = profile.with_email(Some(email));
293        }
294
295        if let Some(verified) = json.get("email_verified").and_then(|v| v.as_bool()) {
296            profile = profile.with_email_verified(verified);
297        }
298
299        if let Some(preferred_username) = json.get("preferred_username").and_then(|v| v.as_str()) {
300            profile = profile.with_username(Some(preferred_username));
301        }
302
303        if let Some(picture) = json.get("picture").and_then(|v| v.as_str()) {
304            profile = profile.with_picture(Some(picture));
305        }
306
307        if let Some(locale) = json.get("locale").and_then(|v| v.as_str()) {
308            profile = profile.with_locale(Some(locale));
309        }
310
311        // Store the entire claims as additional data
312        profile = profile.with_additional_data("id_token_claims", json);
313
314        Ok(profile)
315    }
316
317    /// Create an AuthToken with this profile's information
318    pub fn to_auth_token(&self, access_token: String) -> AuthToken {
319        let user_id = self.id.as_deref().unwrap_or("unknown").to_string();
320        let auth_method = self.provider.as_deref().unwrap_or("oauth").to_string();
321        let expires_in = std::time::Duration::from_secs(3600); // 1 hour default
322
323        let mut token = AuthToken::new(user_id.clone(), access_token, expires_in, auth_method);
324        token.subject = self.id.clone();
325        token.issuer = self.provider.clone();
326        token.user_profile = Some(self.clone());
327        token
328    }
329
330    /// Check if this profile has an ID
331    pub fn has_id(&self) -> bool {
332        self.id.is_some()
333    }
334
335    /// Get display name or fall back to username
336    pub fn display_name(&self) -> Option<&str> {
337        self.name.as_deref().or(self.username.as_deref())
338    }
339}
340
341/// OAuth token response from the provider.
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct OAuthTokenResponse {
344    /// Access token
345    pub access_token: String,
346
347    /// Token type (usually "Bearer")
348    pub token_type: String,
349
350    /// Token expiration in seconds
351    pub expires_in: Option<u64>,
352
353    /// Refresh token (if available)
354    pub refresh_token: Option<String>,
355
356    /// Granted scopes
357    pub scope: Option<String>,
358
359    /// Additional provider-specific fields
360    #[serde(flatten)]
361    pub additional_fields: HashMap<String, serde_json::Value>,
362}
363
364/// User information from OAuth provider.
365#[derive(Debug, Clone, Serialize, Deserialize)]
366pub struct OAuthUserInfo {
367    /// Unique user ID from the provider
368    pub id: String,
369
370    /// Username
371    pub username: Option<String>,
372
373    /// Display name
374    pub name: Option<String>,
375
376    /// Email address
377    pub email: Option<String>,
378
379    /// Whether email is verified
380    pub email_verified: Option<bool>,
381
382    /// Profile picture URL
383    pub picture: Option<String>,
384
385    /// Locale/language preference
386    pub locale: Option<String>,
387
388    /// Additional provider-specific fields
389    #[serde(flatten)]
390    pub additional_fields: HashMap<String, serde_json::Value>,
391}
392
393impl OAuthProvider {
394    /// Get the configuration for this provider.
395    pub fn config(&self) -> OAuthProviderConfig {
396        match self {
397            Self::GitHub => OAuthProviderConfig {
398                authorization_url: "https://github.com/login/oauth/authorize".to_string(),
399                token_url: "https://github.com/login/oauth/access_token".to_string(),
400                device_authorization_url: Some("https://github.com/login/device/code".to_string()),
401                userinfo_url: Some("https://api.github.com/user".to_string()),
402                revocation_url: None,
403                default_scopes: vec!["user:email".to_string()],
404                supports_pkce: true,
405                supports_refresh: false,
406                supports_device_flow: true,
407                additional_params: HashMap::new(),
408            },
409
410            Self::Google => OAuthProviderConfig {
411                authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
412                token_url: "https://oauth2.googleapis.com/token".to_string(),
413                device_authorization_url: Some(
414                    "https://oauth2.googleapis.com/device/code".to_string(),
415                ),
416                userinfo_url: Some("https://www.googleapis.com/oauth2/v2/userinfo".to_string()),
417                revocation_url: Some("https://oauth2.googleapis.com/revoke".to_string()),
418                default_scopes: vec![
419                    "openid".to_string(),
420                    "profile".to_string(),
421                    "email".to_string(),
422                ],
423                supports_pkce: true,
424                supports_refresh: true,
425                supports_device_flow: true,
426                additional_params: HashMap::new(),
427            },
428
429            Self::Microsoft => OAuthProviderConfig {
430                authorization_url: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
431                    .to_string(),
432                token_url: "https://login.microsoftonline.com/common/oauth2/v2.0/token".to_string(),
433                device_authorization_url: Some(
434                    "https://login.microsoftonline.com/common/oauth2/v2.0/devicecode".to_string(),
435                ),
436                userinfo_url: Some("https://graph.microsoft.com/v1.0/me".to_string()),
437                revocation_url: None,
438                default_scopes: vec![
439                    "openid".to_string(),
440                    "profile".to_string(),
441                    "email".to_string(),
442                ],
443                supports_pkce: true,
444                supports_refresh: true,
445                supports_device_flow: true,
446                additional_params: HashMap::new(),
447            },
448
449            Self::Discord => OAuthProviderConfig {
450                authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
451                token_url: "https://discord.com/api/oauth2/token".to_string(),
452                device_authorization_url: None,
453                userinfo_url: Some("https://discord.com/api/users/@me".to_string()),
454                revocation_url: Some("https://discord.com/api/oauth2/token/revoke".to_string()),
455                default_scopes: vec!["identify".to_string(), "email".to_string()],
456                supports_pkce: false,
457                supports_refresh: true,
458                supports_device_flow: false,
459                additional_params: HashMap::new(),
460            },
461
462            Self::Twitter => OAuthProviderConfig {
463                authorization_url: "https://twitter.com/i/oauth2/authorize".to_string(),
464                token_url: "https://api.twitter.com/2/oauth2/token".to_string(),
465                device_authorization_url: None,
466                userinfo_url: Some("https://api.twitter.com/2/users/me".to_string()),
467                revocation_url: Some("https://api.twitter.com/2/oauth2/revoke".to_string()),
468                default_scopes: vec!["tweet.read".to_string(), "users.read".to_string()],
469                supports_pkce: true,
470                supports_refresh: true,
471                supports_device_flow: false,
472                additional_params: HashMap::new(),
473            },
474
475            Self::Facebook => OAuthProviderConfig {
476                authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
477                token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
478                device_authorization_url: None,
479                userinfo_url: Some("https://graph.facebook.com/me".to_string()),
480                revocation_url: None,
481                default_scopes: vec!["email".to_string(), "public_profile".to_string()],
482                supports_pkce: false,
483                supports_refresh: false,
484                supports_device_flow: false,
485                additional_params: HashMap::new(),
486            },
487
488            Self::LinkedIn => OAuthProviderConfig {
489                authorization_url: "https://www.linkedin.com/oauth/v2/authorization".to_string(),
490                token_url: "https://www.linkedin.com/oauth/v2/accessToken".to_string(),
491                device_authorization_url: None,
492                userinfo_url: Some("https://api.linkedin.com/v2/me".to_string()),
493                revocation_url: None,
494                default_scopes: vec!["r_liteprofile".to_string(), "r_emailaddress".to_string()],
495                supports_pkce: false,
496                supports_refresh: true,
497                supports_device_flow: false,
498                additional_params: HashMap::new(),
499            },
500
501            Self::GitLab => OAuthProviderConfig {
502                authorization_url: "https://gitlab.com/oauth/authorize".to_string(),
503                token_url: "https://gitlab.com/oauth/token".to_string(),
504                device_authorization_url: None,
505                userinfo_url: Some("https://gitlab.com/api/v4/user".to_string()),
506                revocation_url: Some("https://gitlab.com/oauth/revoke".to_string()),
507                default_scopes: vec!["read_user".to_string()],
508                supports_pkce: true,
509                supports_refresh: true,
510                supports_device_flow: false,
511                additional_params: HashMap::new(),
512            },
513
514            Self::Custom { config, .. } => *config.clone(),
515        }
516    }
517
518    /// Get the provider name.
519    pub fn name(&self) -> &str {
520        match self {
521            Self::GitHub => "github",
522            Self::Google => "google",
523            Self::Microsoft => "microsoft",
524            Self::Discord => "discord",
525            Self::Twitter => "twitter",
526            Self::Facebook => "facebook",
527            Self::LinkedIn => "linkedin",
528            Self::GitLab => "gitlab",
529            Self::Custom { name, .. } => name,
530        }
531    }
532
533    /// Create a custom OAuth provider.
534    pub fn custom(name: impl Into<String>, config: OAuthProviderConfig) -> Self {
535        Self::Custom {
536            name: name.into(),
537            config: Box::new(config),
538        }
539    }
540
541    /// Build authorization URL.
542    pub fn build_authorization_url(
543        &self,
544        client_id: &str,
545        redirect_uri: &str,
546        state: &str,
547        scopes: Option<&[String]>,
548        code_challenge: Option<&str>,
549    ) -> Result<String> {
550        let config = self.config();
551        let mut url = Url::parse(&config.authorization_url)
552            .map_err(|e| AuthError::config(format!("Invalid authorization URL: {e}")))?;
553
554        let scopes = scopes.unwrap_or(&config.default_scopes);
555
556        {
557            let mut query = url.query_pairs_mut();
558            query.append_pair("client_id", client_id);
559            query.append_pair("redirect_uri", redirect_uri);
560            query.append_pair("response_type", "code");
561            query.append_pair("state", state);
562
563            if !scopes.is_empty() {
564                query.append_pair("scope", &scopes.join(" "));
565            }
566
567            // Add PKCE challenge if supported and provided (Clippy-compliant)
568            if config.supports_pkce
569                && let Some(challenge) = code_challenge
570            {
571                query.append_pair("code_challenge", challenge);
572                query.append_pair("code_challenge_method", "S256");
573            }
574
575            // Add any additional parameters
576            for (key, value) in &config.additional_params {
577                query.append_pair(key, value);
578            }
579        }
580
581        Ok(url.to_string())
582    }
583
584    /// Exchange authorization code for tokens.
585    pub async fn exchange_code(
586        &self,
587        client_id: &str,
588        client_secret: &str,
589        authorization_code: &str,
590        redirect_uri: &str,
591        code_verifier: Option<&str>,
592    ) -> Result<OAuthTokenResponse> {
593        let config = self.config();
594        let client = reqwest::Client::new();
595
596        let mut params = vec![
597            ("grant_type", "authorization_code"),
598            ("client_id", client_id),
599            ("client_secret", client_secret),
600            ("code", authorization_code),
601            ("redirect_uri", redirect_uri),
602        ];
603
604        // Add PKCE verifier if provided
605        if let Some(verifier) = code_verifier {
606            params.push(("code_verifier", verifier));
607        }
608
609        let response = client.post(&config.token_url).form(&params).send().await?;
610
611        if !response.status().is_success() {
612            let error_text = response.text().await.unwrap_or_default();
613            return Err(AuthError::auth_method(
614                self.name(),
615                format!("Token exchange failed: {error_text}"),
616            ));
617        }
618
619        let token_response: OAuthTokenResponse = response.json().await?;
620        Ok(token_response)
621    }
622
623    /// Refresh an access token.
624    pub async fn refresh_token(
625        &self,
626        client_id: &str,
627        client_secret: &str,
628        refresh_token: &str,
629    ) -> Result<OAuthTokenResponse> {
630        let config = self.config();
631
632        if !config.supports_refresh {
633            return Err(AuthError::auth_method(
634                self.name(),
635                "Provider does not support token refresh".to_string(),
636            ));
637        }
638
639        let client = reqwest::Client::new();
640
641        let params = vec![
642            ("grant_type", "refresh_token"),
643            ("client_id", client_id),
644            ("client_secret", client_secret),
645            ("refresh_token", refresh_token),
646        ];
647
648        let response = client.post(&config.token_url).form(&params).send().await?;
649
650        if !response.status().is_success() {
651            let error_text = response.text().await.unwrap_or_default();
652            return Err(AuthError::auth_method(
653                self.name(),
654                format!("Token refresh failed: {error_text}"),
655            ));
656        }
657
658        let token_response: OAuthTokenResponse = response.json().await?;
659        Ok(token_response)
660    }
661
662    /// Get user information using an access token.
663    pub async fn get_user_info(&self, access_token: &str) -> Result<OAuthUserInfo> {
664        let config = self.config();
665
666        let userinfo_url = config.userinfo_url.ok_or_else(|| {
667            AuthError::auth_method(
668                self.name(),
669                "Provider does not support user info endpoint".to_string(),
670            )
671        })?;
672
673        let client = reqwest::Client::new();
674        let response = client
675            .get(&userinfo_url)
676            .bearer_auth(access_token)
677            .send()
678            .await?;
679
680        if !response.status().is_success() {
681            let error_text = response.text().await.unwrap_or_default();
682            return Err(AuthError::auth_method(
683                self.name(),
684                format!("User info request failed: {error_text}"),
685            ));
686        }
687
688        let user_data: serde_json::Value = response.json().await?;
689
690        // Convert provider-specific user data to our standard format
691        let user_info = self.parse_user_info(user_data)?;
692        Ok(user_info)
693    }
694
695    /// Parse provider-specific user info into our standard format.
696    fn parse_user_info(&self, data: serde_json::Value) -> Result<OAuthUserInfo> {
697        let mut additional_fields = HashMap::new();
698
699        let user_info = match self {
700            Self::GitHub => {
701                let id = data["id"]
702                    .as_u64()
703                    .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
704                    .to_string();
705
706                OAuthUserInfo {
707                    id,
708                    username: data["login"].as_str().map(|s| s.to_string()),
709                    email: data["email"].as_str().map(|s| s.to_string()),
710                    name: data["name"].as_str().map(|s| s.to_string()),
711                    picture: data["avatar_url"].as_str().map(|s| s.to_string()),
712                    email_verified: None, // GitHub doesn't provide this directly
713                    locale: None,
714                    additional_fields,
715                }
716            }
717
718            Self::Google => {
719                let id = data["id"]
720                    .as_str()
721                    .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
722                    .to_string();
723
724                OAuthUserInfo {
725                    id,
726                    username: None, // Google doesn't provide username
727                    email: data["email"].as_str().map(|s| s.to_string()),
728                    name: data["name"].as_str().map(|s| s.to_string()),
729                    picture: data["picture"].as_str().map(|s| s.to_string()),
730                    email_verified: data["verified_email"].as_bool(),
731                    locale: data["locale"].as_str().map(|s| s.to_string()),
732                    additional_fields,
733                }
734            }
735
736            // Add other provider-specific parsing...
737            _ => {
738                // Generic parsing for custom providers
739                let id = data["id"]
740                    .as_str()
741                    .or_else(|| data["sub"].as_str())
742                    .or_else(|| data["user_id"].as_str())
743                    .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
744                    .to_string();
745
746                // Copy all fields to additional_fields for custom providers
747                if let serde_json::Value::Object(map) = data {
748                    additional_fields = map.into_iter().collect();
749                }
750
751                OAuthUserInfo {
752                    id,
753                    username: additional_fields
754                        .get("username")
755                        .or_else(|| additional_fields.get("login"))
756                        .and_then(|v| v.as_str())
757                        .map(|s| s.to_string()),
758                    email: additional_fields
759                        .get("email")
760                        .and_then(|v| v.as_str())
761                        .map(|s| s.to_string()),
762                    name: additional_fields
763                        .get("name")
764                        .or_else(|| additional_fields.get("display_name"))
765                        .and_then(|v| v.as_str())
766                        .map(|s| s.to_string()),
767                    picture: additional_fields
768                        .get("avatar_url")
769                        .or_else(|| additional_fields.get("picture"))
770                        .and_then(|v| v.as_str())
771                        .map(|s| s.to_string()),
772                    email_verified: additional_fields
773                        .get("email_verified")
774                        .and_then(|v| v.as_bool()),
775                    locale: additional_fields
776                        .get("locale")
777                        .and_then(|v| v.as_str())
778                        .map(|s| s.to_string()),
779                    additional_fields,
780                }
781            }
782        };
783
784        Ok(user_info)
785    }
786
787    /// Revoke a token if the provider supports it.
788    pub async fn revoke_token(&self, access_token: &str) -> Result<()> {
789        let config = self.config();
790
791        let revocation_url = config.revocation_url.ok_or_else(|| {
792            AuthError::auth_method(
793                self.name(),
794                "Provider does not support token revocation".to_string(),
795            )
796        })?;
797
798        let client = reqwest::Client::new();
799        let response = client
800            .post(&revocation_url)
801            .form(&[("token", access_token)])
802            .send()
803            .await?;
804
805        if !response.status().is_success() {
806            let error_text = response.text().await.unwrap_or_default();
807            return Err(AuthError::auth_method(
808                self.name(),
809                format!("Token revocation failed: {error_text}"),
810            ));
811        }
812
813        Ok(())
814    }
815
816    /// Perform device authorization flow.
817    pub async fn device_authorization(
818        &self,
819        client_id: &str,
820        scope: Option<&[String]>,
821    ) -> Result<DeviceAuthorizationResponse> {
822        let config = self.config();
823
824        if !config.supports_device_flow {
825            return Err(AuthError::auth_method(
826                self.name(),
827                "Provider does not support device authorization flow".to_string(),
828            ));
829        }
830
831        let client = reqwest::Client::new();
832
833        let scope_string = scope.unwrap_or(&config.default_scopes).join(" ");
834        let params = vec![("client_id", client_id), ("scope", scope_string.as_str())];
835
836        let response = client
837            .post(config.device_authorization_url.as_deref().unwrap())
838            .form(&params)
839            .send()
840            .await?;
841
842        if !response.status().is_success() {
843            let error_text = response.text().await.unwrap_or_default();
844            return Err(AuthError::auth_method(
845                self.name(),
846                format!("Device authorization request failed: {error_text}"),
847            ));
848        }
849
850        let device_response: DeviceAuthorizationResponse = response.json().await?;
851        Ok(device_response)
852    }
853
854    /// Poll for access token using device code.
855    pub async fn poll_device_code(
856        &self,
857        client_id: &str,
858        device_code: &str,
859        _interval: Option<u64>,
860    ) -> Result<OAuthTokenResponse> {
861        let config = self.config();
862
863        if !config.supports_device_flow {
864            return Err(AuthError::auth_method(
865                self.name(),
866                "Provider does not support device authorization flow".to_string(),
867            ));
868        }
869
870        let client = reqwest::Client::new();
871
872        let params = vec![
873            ("client_id", client_id),
874            ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
875            ("device_code", device_code),
876        ];
877
878        let response = client.post(&config.token_url).form(&params).send().await?;
879
880        if !response.status().is_success() {
881            let error_text = response.text().await.unwrap_or_default();
882            return Err(AuthError::auth_method(
883                self.name(),
884                format!("Token request failed: {error_text}"),
885            ));
886        }
887
888        let token_response: OAuthTokenResponse = response.json().await?;
889        Ok(token_response)
890    }
891}
892
893/// Generate a random state parameter for OAuth flows.
894pub fn generate_state() -> String {
895    let mut bytes = [0u8; 32];
896    use rand::RngCore;
897    rand::thread_rng().fill_bytes(&mut bytes);
898    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
899}
900
901/// Generate PKCE code verifier and challenge.
902pub fn generate_pkce() -> (String, String) {
903    use rand::RngCore;
904    use ring::digest;
905
906    // Generate code verifier (43-128 characters)
907    let mut rng = rand::thread_rng();
908    let mut bytes = [0u8; 96]; // 96 bytes = 128 base64 characters
909    rng.fill_bytes(&mut bytes);
910    let code_verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
911
912    // Generate code challenge (SHA256 hash of verifier, base64url encoded)
913    let digest = digest::digest(&digest::SHA256, code_verifier.as_bytes());
914    let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest.as_ref());
915
916    (code_verifier, code_challenge)
917}
918
919/// Automated token-to-profile conversion utilities
920pub struct ProfileExtractor {
921    client: Client,
922}
923
924impl ProfileExtractor {
925    /// Create a new profile extractor
926    pub fn new() -> Self {
927        Self {
928            client: Client::new(),
929        }
930    }
931
932    /// Extract user profile from token automatically based on provider
933    pub async fn extract_profile(
934        &self,
935        token: &AuthToken,
936        provider: &OAuthProvider,
937    ) -> Result<UserProfile> {
938        match provider {
939            OAuthProvider::GitHub => self.extract_github_profile(token).await,
940            OAuthProvider::Google => self.extract_google_profile(token).await,
941            OAuthProvider::Microsoft => self.extract_microsoft_profile(token).await,
942            OAuthProvider::Discord => self.extract_discord_profile(token).await,
943            OAuthProvider::GitLab => self.extract_gitlab_profile(token).await,
944            OAuthProvider::Custom { name, config } => {
945                self.extract_custom_profile(token, name, config).await
946            }
947            _ => Err(AuthError::UnsupportedProvider(format!(
948                "Profile extraction not supported for {:?}",
949                provider
950            ))),
951        }
952    }
953
954    /// Extract GitHub user profile
955    async fn extract_github_profile(&self, token: &AuthToken) -> Result<UserProfile> {
956        let response = self
957            .client
958            .get("https://api.github.com/user")
959            .bearer_auth(&token.access_token)
960            .send()
961            .await
962            .map_err(|e| AuthError::NetworkError(e.to_string()))?;
963
964        let json: Value = response
965            .json()
966            .await
967            .map_err(|e| AuthError::ParseError(e.to_string()))?;
968
969        let mut profile = UserProfile::new();
970        profile = profile.with_id(json["id"].as_u64().unwrap_or(0).to_string());
971        profile = profile.with_provider("github".to_string());
972
973        if let Some(login) = json["login"].as_str() {
974            profile.username = Some(login.to_string());
975        }
976
977        if let Some(name) = json["name"].as_str() {
978            profile.name = Some(name.to_string());
979        }
980
981        if let Some(email) = json["email"].as_str() {
982            profile.email = Some(email.to_string());
983        }
984
985        if let Some(avatar_url) = json["avatar_url"].as_str() {
986            profile.picture = Some(avatar_url.to_string());
987        }
988
989        // Store additional GitHub-specific data
990        if let Some(company) = json["company"].as_str() {
991            profile
992                .additional_data
993                .insert("company".to_string(), Value::String(company.to_string()));
994        }
995
996        if let Some(blog) = json["blog"].as_str() {
997            profile
998                .additional_data
999                .insert("blog".to_string(), Value::String(blog.to_string()));
1000        }
1001
1002        if let Some(bio) = json["bio"].as_str() {
1003            profile
1004                .additional_data
1005                .insert("bio".to_string(), Value::String(bio.to_string()));
1006        }
1007
1008        Ok(profile)
1009    }
1010
1011    /// Extract Google user profile
1012    async fn extract_google_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1013        let response = self
1014            .client
1015            .get("https://www.googleapis.com/oauth2/v2/userinfo")
1016            .bearer_auth(&token.access_token)
1017            .send()
1018            .await
1019            .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1020
1021        let json: Value = response
1022            .json()
1023            .await
1024            .map_err(|e| AuthError::ParseError(e.to_string()))?;
1025
1026        let mut profile = UserProfile::new();
1027        profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1028        profile = profile.with_provider("google".to_string());
1029
1030        if let Some(name) = json["name"].as_str() {
1031            profile.name = Some(name.to_string());
1032        }
1033
1034        if let Some(email) = json["email"].as_str() {
1035            profile.email = Some(email.to_string());
1036        }
1037
1038        if let Some(verified) = json["verified_email"].as_bool() {
1039            profile.email_verified = Some(verified);
1040        }
1041
1042        if let Some(picture) = json["picture"].as_str() {
1043            profile.picture = Some(picture.to_string());
1044        }
1045
1046        if let Some(locale) = json["locale"].as_str() {
1047            profile.locale = Some(locale.to_string());
1048        }
1049
1050        Ok(profile)
1051    }
1052
1053    /// Extract Microsoft user profile
1054    async fn extract_microsoft_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1055        let response = self
1056            .client
1057            .get("https://graph.microsoft.com/v1.0/me")
1058            .bearer_auth(&token.access_token)
1059            .send()
1060            .await
1061            .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1062
1063        let json: Value = response
1064            .json()
1065            .await
1066            .map_err(|e| AuthError::ParseError(e.to_string()))?;
1067
1068        let mut profile = UserProfile::new();
1069        profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1070        profile = profile.with_provider("microsoft".to_string());
1071
1072        if let Some(display_name) = json["displayName"].as_str() {
1073            profile.name = Some(display_name.to_string());
1074        }
1075
1076        if let Some(user_principal_name) = json["userPrincipalName"].as_str() {
1077            profile.username = Some(user_principal_name.to_string());
1078        }
1079
1080        if let Some(mail) = json["mail"].as_str() {
1081            profile.email = Some(mail.to_string());
1082        }
1083
1084        if let Some(preferred_language) = json["preferredLanguage"].as_str() {
1085            profile.locale = Some(preferred_language.to_string());
1086        }
1087
1088        // Store additional Microsoft-specific data
1089        if let Some(job_title) = json["jobTitle"].as_str() {
1090            profile
1091                .additional_data
1092                .insert("jobTitle".to_string(), Value::String(job_title.to_string()));
1093        }
1094
1095        if let Some(office_location) = json["officeLocation"].as_str() {
1096            profile.additional_data.insert(
1097                "officeLocation".to_string(),
1098                Value::String(office_location.to_string()),
1099            );
1100        }
1101
1102        Ok(profile)
1103    }
1104
1105    /// Extract Discord user profile
1106    async fn extract_discord_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1107        let response = self
1108            .client
1109            .get("https://discord.com/api/users/@me")
1110            .bearer_auth(&token.access_token)
1111            .send()
1112            .await
1113            .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1114
1115        let json: Value = response
1116            .json()
1117            .await
1118            .map_err(|e| AuthError::ParseError(e.to_string()))?;
1119
1120        let mut profile = UserProfile::new();
1121        profile = profile.with_id(json["id"].as_str().unwrap_or("").to_string());
1122        profile = profile.with_provider("discord".to_string());
1123
1124        if let Some(username) = json["username"].as_str() {
1125            profile.username = Some(username.to_string());
1126        }
1127
1128        if let Some(discriminator) = json["discriminator"].as_str() {
1129            profile.name = Some(format!(
1130                "{}#{}",
1131                json["username"].as_str().unwrap_or(""),
1132                discriminator
1133            ));
1134        }
1135
1136        if let Some(email) = json["email"].as_str() {
1137            profile.email = Some(email.to_string());
1138        }
1139
1140        if let Some(verified) = json["verified"].as_bool() {
1141            profile.email_verified = Some(verified);
1142        }
1143
1144        if let Some(avatar) = json["avatar"].as_str() {
1145            let user_id = json["id"].as_str().unwrap_or("");
1146            profile.picture = Some(format!(
1147                "https://cdn.discordapp.com/avatars/{}/{}.png",
1148                user_id, avatar
1149            ));
1150        }
1151
1152        if let Some(locale) = json["locale"].as_str() {
1153            profile.locale = Some(locale.to_string());
1154        }
1155
1156        Ok(profile)
1157    }
1158
1159    /// Extract GitLab user profile
1160    async fn extract_gitlab_profile(&self, token: &AuthToken) -> Result<UserProfile> {
1161        let response = self
1162            .client
1163            .get("https://gitlab.com/api/v4/user")
1164            .bearer_auth(&token.access_token)
1165            .send()
1166            .await
1167            .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1168
1169        let json: Value = response
1170            .json()
1171            .await
1172            .map_err(|e| AuthError::ParseError(e.to_string()))?;
1173
1174        let mut profile = UserProfile::new();
1175        profile = profile.with_id(json["id"].as_u64().unwrap_or(0).to_string());
1176        profile = profile.with_provider("gitlab".to_string());
1177
1178        if let Some(username) = json["username"].as_str() {
1179            profile.username = Some(username.to_string());
1180        }
1181
1182        if let Some(name) = json["name"].as_str() {
1183            profile.name = Some(name.to_string());
1184        }
1185
1186        if let Some(email) = json["email"].as_str() {
1187            profile.email = Some(email.to_string());
1188        }
1189
1190        if let Some(avatar_url) = json["avatar_url"].as_str() {
1191            profile.picture = Some(avatar_url.to_string());
1192        }
1193
1194        // Store additional GitLab-specific data
1195        if let Some(web_url) = json["web_url"].as_str() {
1196            profile
1197                .additional_data
1198                .insert("web_url".to_string(), Value::String(web_url.to_string()));
1199        }
1200
1201        if let Some(bio) = json["bio"].as_str() {
1202            profile
1203                .additional_data
1204                .insert("bio".to_string(), Value::String(bio.to_string()));
1205        }
1206
1207        Ok(profile)
1208    }
1209
1210    /// Extract custom provider profile
1211    async fn extract_custom_profile(
1212        &self,
1213        token: &AuthToken,
1214        provider_name: &str,
1215        config: &OAuthProviderConfig,
1216    ) -> Result<UserProfile> {
1217        if let Some(user_info_url) = &config.userinfo_url {
1218            let response = self
1219                .client
1220                .get(user_info_url)
1221                .bearer_auth(&token.access_token)
1222                .send()
1223                .await
1224                .map_err(|e| AuthError::NetworkError(e.to_string()))?;
1225
1226            let json: Value = response
1227                .json()
1228                .await
1229                .map_err(|e| AuthError::ParseError(e.to_string()))?;
1230
1231            let mut profile = UserProfile::new();
1232            profile = profile.with_id(
1233                json["id"]
1234                    .as_str()
1235                    .or_else(|| json["sub"].as_str())
1236                    .unwrap_or("")
1237                    .to_string(),
1238            );
1239            profile = profile.with_provider(provider_name.to_string());
1240
1241            // Try common field names
1242            if let Some(username) = json["username"].as_str().or_else(|| json["login"].as_str()) {
1243                profile.username = Some(username.to_string());
1244            }
1245
1246            if let Some(name) = json["name"]
1247                .as_str()
1248                .or_else(|| json["display_name"].as_str())
1249            {
1250                profile.name = Some(name.to_string());
1251            }
1252
1253            if let Some(email) = json["email"].as_str() {
1254                profile.email = Some(email.to_string());
1255            }
1256
1257            if let Some(verified) = json["email_verified"]
1258                .as_bool()
1259                .or_else(|| json["verified"].as_bool())
1260            {
1261                profile.email_verified = Some(verified);
1262            }
1263
1264            if let Some(picture) = json["picture"]
1265                .as_str()
1266                .or_else(|| json["avatar_url"].as_str())
1267            {
1268                profile.picture = Some(picture.to_string());
1269            }
1270
1271            if let Some(locale) = json["locale"].as_str().or_else(|| json["lang"].as_str()) {
1272                profile.locale = Some(locale.to_string());
1273            }
1274
1275            // Store all additional data
1276            for (key, value) in json.as_object().unwrap_or(&serde_json::Map::new()) {
1277                if ![
1278                    "id",
1279                    "sub",
1280                    "username",
1281                    "login",
1282                    "name",
1283                    "display_name",
1284                    "email",
1285                    "email_verified",
1286                    "verified",
1287                    "picture",
1288                    "avatar_url",
1289                    "locale",
1290                    "lang",
1291                ]
1292                .contains(&key.as_str())
1293                {
1294                    profile.additional_data.insert(key.clone(), value.clone());
1295                }
1296            }
1297
1298            Ok(profile)
1299        } else {
1300            Err(AuthError::ConfigurationError(
1301                "Custom provider requires user_info_url".to_string(),
1302            ))
1303        }
1304    }
1305}
1306
1307impl Default for ProfileExtractor {
1308    fn default() -> Self {
1309        Self::new()
1310    }
1311}
1312
1313impl fmt::Display for OAuthProvider {
1314    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1315        match self {
1316            OAuthProvider::GitHub => write!(f, "github"),
1317            OAuthProvider::Google => write!(f, "google"),
1318            OAuthProvider::Microsoft => write!(f, "microsoft"),
1319            OAuthProvider::Discord => write!(f, "discord"),
1320            OAuthProvider::Twitter => write!(f, "twitter"),
1321            OAuthProvider::Facebook => write!(f, "facebook"),
1322            OAuthProvider::LinkedIn => write!(f, "linkedin"),
1323            OAuthProvider::GitLab => write!(f, "gitlab"),
1324            OAuthProvider::Custom { name, .. } => write!(f, "{}", name),
1325        }
1326    }
1327}
1328
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;

    /// Characters produced by unpadded base64url encoding, all of which are in
    /// the RFC 7636 "unreserved" set allowed for PKCE verifiers.
    fn is_base64url_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '-' || c == '_'
    }

    #[test]
    fn test_provider_config() {
        let github = OAuthProvider::GitHub;
        let config = github.config();

        assert_eq!(
            config.authorization_url,
            "https://github.com/login/oauth/authorize"
        );
        assert_eq!(
            config.token_url,
            "https://github.com/login/oauth/access_token"
        );
        assert!(config.supports_pkce);
    }

    #[test]
    fn test_authorization_url() {
        let github = OAuthProvider::GitHub;
        let url = github
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                None,
                Some("challenge123"),
            )
            .unwrap();

        assert!(url.contains("client_id=client123"));
        assert!(url.contains("redirect_uri=https%3A%2F%2Fexample.com%2Fcallback"));
        assert!(url.contains("state=state123"));
        assert!(url.contains("code_challenge=challenge123"));
    }

    #[test]
    fn test_generate_state() {
        let state1 = generate_state();
        let state2 = generate_state();

        // 32 random bytes encode to 43 unpadded base64url characters.
        assert_eq!(state1.len(), 43);
        assert_eq!(state2.len(), 43);
        assert!(state1.chars().all(is_base64url_char));
        assert_ne!(state1, state2);
    }

    #[test]
    fn test_generate_pkce() {
        let (verifier1, challenge1) = generate_pkce();
        let (verifier2, challenge2) = generate_pkce();

        // 128 characters is the maximum verifier length in RFC 7636.
        assert_eq!(verifier1.len(), 128);
        assert_eq!(verifier2.len(), 128);
        assert!(verifier1.chars().all(is_base64url_char));
        assert_ne!(verifier1, verifier2);
        assert_ne!(challenge1, challenge2);

        // S256: challenge = BASE64URL-NOPAD(SHA256(verifier)); 32 bytes -> 43 chars.
        let expected = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(
            ring::digest::digest(&ring::digest::SHA256, verifier1.as_bytes()).as_ref(),
        );
        assert_eq!(challenge1, expected);
        assert_eq!(challenge1.len(), 43);
    }

    #[test]
    fn test_provider_display() {
        assert_eq!(OAuthProvider::GitHub.to_string(), "github");
        assert_eq!(OAuthProvider::LinkedIn.to_string(), "linkedin");
        assert_eq!(OAuthProvider::GitLab.to_string(), "gitlab");
    }
}
1389
1390