auth_framework/
providers.rs

1//! OAuth provider configurations and implementations.
2
3use base64::Engine;
4use crate::errors::{AuthError, Result};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use url::Url;
8
9/// Supported OAuth providers.
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub enum OAuthProvider {
12    /// GitHub OAuth provider
13    GitHub,
14    
15    /// Google OAuth provider
16    Google,
17    
18    /// Microsoft OAuth provider
19    Microsoft,
20    
21    /// Discord OAuth provider
22    Discord,
23    
24    /// Twitter OAuth provider
25    Twitter,
26    
27    /// Facebook OAuth provider
28    Facebook,
29    
30    /// LinkedIn OAuth provider
31    LinkedIn,
32    
33    /// GitLab OAuth provider
34    GitLab,
35    
36    /// Generic OAuth provider with custom configuration
37    Custom {
38        name: String,
39        config: OAuthProviderConfig,
40    },
41}
42
43/// OAuth provider configuration.
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub struct OAuthProviderConfig {
46    /// Authorization endpoint URL
47    pub authorization_url: String,
48    
49    /// Token endpoint URL
50    pub token_url: String,
51    
52    /// User info endpoint URL
53    pub userinfo_url: Option<String>,
54    
55    /// Revocation endpoint URL
56    pub revocation_url: Option<String>,
57    
58    /// Default scopes to request
59    pub default_scopes: Vec<String>,
60    
61    /// Whether this provider supports PKCE
62    pub supports_pkce: bool,
63    
64    /// Whether this provider supports refresh tokens
65    pub supports_refresh: bool,
66    
67    /// Custom parameters to include in authorization requests
68    pub additional_params: HashMap<String, String>,
69}
70
71/// OAuth token response from the provider.
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct OAuthTokenResponse {
74    /// Access token
75    pub access_token: String,
76    
77    /// Token type (usually "Bearer")
78    pub token_type: String,
79    
80    /// Token expiration in seconds
81    pub expires_in: Option<u64>,
82    
83    /// Refresh token (if available)
84    pub refresh_token: Option<String>,
85    
86    /// Granted scopes
87    pub scope: Option<String>,
88    
89    /// Additional provider-specific fields
90    #[serde(flatten)]
91    pub additional_fields: HashMap<String, serde_json::Value>,
92}
93
94/// User information from OAuth provider.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct OAuthUserInfo {
97    /// Unique user ID from the provider
98    pub id: String,
99    
100    /// Username
101    pub username: Option<String>,
102    
103    /// Email address
104    pub email: Option<String>,
105    
106    /// Display name
107    pub name: Option<String>,
108    
109    /// Profile picture URL
110    pub avatar_url: Option<String>,
111    
112    /// Whether email is verified
113    pub email_verified: Option<bool>,
114    
115    /// User's locale
116    pub locale: Option<String>,
117    
118    /// Additional provider-specific fields
119    pub additional_fields: HashMap<String, serde_json::Value>,
120}
121
122impl OAuthProvider {
123    /// Get the configuration for this provider.
124    pub fn config(&self) -> OAuthProviderConfig {
125        match self {
126            Self::GitHub => OAuthProviderConfig {
127                authorization_url: "https://github.com/login/oauth/authorize".to_string(),
128                token_url: "https://github.com/login/oauth/access_token".to_string(),
129                userinfo_url: Some("https://api.github.com/user".to_string()),
130                revocation_url: None,
131                default_scopes: vec!["user:email".to_string()],
132                supports_pkce: true,
133                supports_refresh: false,
134                additional_params: HashMap::new(),
135            },
136            
137            Self::Google => OAuthProviderConfig {
138                authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
139                token_url: "https://oauth2.googleapis.com/token".to_string(),
140                userinfo_url: Some("https://www.googleapis.com/oauth2/v2/userinfo".to_string()),
141                revocation_url: Some("https://oauth2.googleapis.com/revoke".to_string()),
142                default_scopes: vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
143                supports_pkce: true,
144                supports_refresh: true,
145                additional_params: HashMap::new(),
146            },
147            
148            Self::Microsoft => OAuthProviderConfig {
149                authorization_url: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize".to_string(),
150                token_url: "https://login.microsoftonline.com/common/oauth2/v2.0/token".to_string(),
151                userinfo_url: Some("https://graph.microsoft.com/v1.0/me".to_string()),
152                revocation_url: None,
153                default_scopes: vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
154                supports_pkce: true,
155                supports_refresh: true,
156                additional_params: HashMap::new(),
157            },
158            
159            Self::Discord => OAuthProviderConfig {
160                authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
161                token_url: "https://discord.com/api/oauth2/token".to_string(),
162                userinfo_url: Some("https://discord.com/api/users/@me".to_string()),
163                revocation_url: Some("https://discord.com/api/oauth2/token/revoke".to_string()),
164                default_scopes: vec!["identify".to_string(), "email".to_string()],
165                supports_pkce: false,
166                supports_refresh: true,
167                additional_params: HashMap::new(),
168            },
169            
170            Self::Twitter => OAuthProviderConfig {
171                authorization_url: "https://twitter.com/i/oauth2/authorize".to_string(),
172                token_url: "https://api.twitter.com/2/oauth2/token".to_string(),
173                userinfo_url: Some("https://api.twitter.com/2/users/me".to_string()),
174                revocation_url: Some("https://api.twitter.com/2/oauth2/revoke".to_string()),
175                default_scopes: vec!["tweet.read".to_string(), "users.read".to_string()],
176                supports_pkce: true,
177                supports_refresh: true,
178                additional_params: HashMap::new(),
179            },
180            
181            Self::Facebook => OAuthProviderConfig {
182                authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
183                token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
184                userinfo_url: Some("https://graph.facebook.com/me".to_string()),
185                revocation_url: None,
186                default_scopes: vec!["email".to_string(), "public_profile".to_string()],
187                supports_pkce: false,
188                supports_refresh: false,
189                additional_params: HashMap::new(),
190            },
191            
192            Self::LinkedIn => OAuthProviderConfig {
193                authorization_url: "https://www.linkedin.com/oauth/v2/authorization".to_string(),
194                token_url: "https://www.linkedin.com/oauth/v2/accessToken".to_string(),
195                userinfo_url: Some("https://api.linkedin.com/v2/me".to_string()),
196                revocation_url: None,
197                default_scopes: vec!["r_liteprofile".to_string(), "r_emailaddress".to_string()],
198                supports_pkce: false,
199                supports_refresh: true,
200                additional_params: HashMap::new(),
201            },
202            
203            Self::GitLab => OAuthProviderConfig {
204                authorization_url: "https://gitlab.com/oauth/authorize".to_string(),
205                token_url: "https://gitlab.com/oauth/token".to_string(),
206                userinfo_url: Some("https://gitlab.com/api/v4/user".to_string()),
207                revocation_url: None,
208                default_scopes: vec!["read_user".to_string()],
209                supports_pkce: true,
210                supports_refresh: true,
211                additional_params: HashMap::new(),
212            },
213            
214            Self::Custom { config, .. } => config.clone(),
215        }
216    }
217
218    /// Get the provider name.
219    pub fn name(&self) -> &str {
220        match self {
221            Self::GitHub => "github",
222            Self::Google => "google",
223            Self::Microsoft => "microsoft",
224            Self::Discord => "discord",
225            Self::Twitter => "twitter",
226            Self::Facebook => "facebook",
227            Self::LinkedIn => "linkedin",
228            Self::GitLab => "gitlab",
229            Self::Custom { name, .. } => name,
230        }
231    }
232
233    /// Create a custom OAuth provider.
234    pub fn custom(name: impl Into<String>, config: OAuthProviderConfig) -> Self {
235        Self::Custom {
236            name: name.into(),
237            config,
238        }
239    }
240
241    /// Build authorization URL.
242    pub fn build_authorization_url(
243        &self,
244        client_id: &str,
245        redirect_uri: &str,
246        state: &str,
247        scopes: Option<&[String]>,
248        code_challenge: Option<&str>,
249    ) -> Result<String> {
250        let config = self.config();
251        let mut url = Url::parse(&config.authorization_url)
252            .map_err(|e| AuthError::config(format!("Invalid authorization URL: {e}")))?;
253
254        let scopes = scopes.unwrap_or(&config.default_scopes);
255        
256        {
257            let mut query = url.query_pairs_mut();
258            query.append_pair("client_id", client_id);
259            query.append_pair("redirect_uri", redirect_uri);
260            query.append_pair("response_type", "code");
261            query.append_pair("state", state);
262            
263            if !scopes.is_empty() {
264                query.append_pair("scope", &scopes.join(" "));
265            }
266
267            // Add PKCE challenge if supported and provided
268            if config.supports_pkce {
269                if let Some(challenge) = code_challenge {
270                    query.append_pair("code_challenge", challenge);
271                    query.append_pair("code_challenge_method", "S256");
272                }
273            }
274
275            // Add any additional parameters
276            for (key, value) in &config.additional_params {
277                query.append_pair(key, value);
278            }
279        }
280
281        Ok(url.to_string())
282    }
283
284    /// Exchange authorization code for tokens.
285    pub async fn exchange_code(
286        &self,
287        client_id: &str,
288        client_secret: &str,
289        authorization_code: &str,
290        redirect_uri: &str,
291        code_verifier: Option<&str>,
292    ) -> Result<OAuthTokenResponse> {
293        let config = self.config();
294        let client = reqwest::Client::new();
295
296        let mut params = vec![
297            ("grant_type", "authorization_code"),
298            ("client_id", client_id),
299            ("client_secret", client_secret),
300            ("code", authorization_code),
301            ("redirect_uri", redirect_uri),
302        ];
303
304        // Add PKCE verifier if provided
305        if let Some(verifier) = code_verifier {
306            params.push(("code_verifier", verifier));
307        }
308
309        let response = client
310            .post(&config.token_url)
311            .form(&params)
312            .send()
313            .await?;
314
315        if !response.status().is_success() {
316            let error_text = response.text().await.unwrap_or_default();
317            return Err(AuthError::auth_method(
318                self.name(),
319                format!("Token exchange failed: {error_text}"),
320            ));
321        }
322
323        let token_response: OAuthTokenResponse = response.json().await?;
324        Ok(token_response)
325    }
326
327    /// Refresh an access token.
328    pub async fn refresh_token(
329        &self,
330        client_id: &str,
331        client_secret: &str,
332        refresh_token: &str,
333    ) -> Result<OAuthTokenResponse> {
334        let config = self.config();
335        
336        if !config.supports_refresh {
337            return Err(AuthError::auth_method(
338                self.name(),
339                "Provider does not support token refresh".to_string(),
340            ));
341        }
342
343        let client = reqwest::Client::new();
344
345        let params = vec![
346            ("grant_type", "refresh_token"),
347            ("client_id", client_id),
348            ("client_secret", client_secret),
349            ("refresh_token", refresh_token),
350        ];
351
352        let response = client
353            .post(&config.token_url)
354            .form(&params)
355            .send()
356            .await?;
357
358        if !response.status().is_success() {
359            let error_text = response.text().await.unwrap_or_default();
360            return Err(AuthError::auth_method(
361                self.name(),
362                format!("Token refresh failed: {error_text}"),
363            ));
364        }
365
366        let token_response: OAuthTokenResponse = response.json().await?;
367        Ok(token_response)
368    }
369
370    /// Get user information using an access token.
371    pub async fn get_user_info(&self, access_token: &str) -> Result<OAuthUserInfo> {
372        let config = self.config();
373        
374        let userinfo_url = config.userinfo_url.ok_or_else(|| {
375            AuthError::auth_method(
376                self.name(),
377                "Provider does not support user info endpoint".to_string(),
378            )
379        })?;
380
381        let client = reqwest::Client::new();
382        let response = client
383            .get(&userinfo_url)
384            .bearer_auth(access_token)
385            .send()
386            .await?;
387
388        if !response.status().is_success() {
389            let error_text = response.text().await.unwrap_or_default();
390            return Err(AuthError::auth_method(
391                self.name(),
392                format!("User info request failed: {error_text}"),
393            ));
394        }
395
396        let user_data: serde_json::Value = response.json().await?;
397        
398        // Convert provider-specific user data to our standard format
399        let user_info = self.parse_user_info(user_data)?;
400        Ok(user_info)
401    }
402
403    /// Parse provider-specific user info into our standard format.
404    fn parse_user_info(&self, data: serde_json::Value) -> Result<OAuthUserInfo> {
405        let mut additional_fields = HashMap::new();
406        
407        let user_info = match self {
408            Self::GitHub => {
409                let id = data["id"].as_u64()
410                    .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
411                    .to_string();
412                
413                OAuthUserInfo {
414                    id,
415                    username: data["login"].as_str().map(|s| s.to_string()),
416                    email: data["email"].as_str().map(|s| s.to_string()),
417                    name: data["name"].as_str().map(|s| s.to_string()),
418                    avatar_url: data["avatar_url"].as_str().map(|s| s.to_string()),
419                    email_verified: None, // GitHub doesn't provide this directly
420                    locale: None,
421                    additional_fields,
422                }
423            }
424            
425            Self::Google => {
426                let id = data["id"].as_str()
427                    .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
428                    .to_string();
429                
430                OAuthUserInfo {
431                    id,
432                    username: None, // Google doesn't provide username
433                    email: data["email"].as_str().map(|s| s.to_string()),
434                    name: data["name"].as_str().map(|s| s.to_string()),
435                    avatar_url: data["picture"].as_str().map(|s| s.to_string()),
436                    email_verified: data["verified_email"].as_bool(),
437                    locale: data["locale"].as_str().map(|s| s.to_string()),
438                    additional_fields,
439                }
440            }
441            
442            // Add other provider-specific parsing...
443            _ => {
444                // Generic parsing for custom providers
445                let id = data["id"].as_str()
446                    .or_else(|| data["sub"].as_str())
447                    .or_else(|| data["user_id"].as_str())
448                    .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
449                    .to_string();
450
451                // Copy all fields to additional_fields for custom providers
452                if let serde_json::Value::Object(map) = data {
453                    additional_fields = map.into_iter().collect();
454                }
455                
456                OAuthUserInfo {
457                    id,
458                    username: additional_fields.get("username")
459                        .or_else(|| additional_fields.get("login"))
460                        .and_then(|v| v.as_str())
461                        .map(|s| s.to_string()),
462                    email: additional_fields.get("email")
463                        .and_then(|v| v.as_str())
464                        .map(|s| s.to_string()),
465                    name: additional_fields.get("name")
466                        .or_else(|| additional_fields.get("display_name"))
467                        .and_then(|v| v.as_str())
468                        .map(|s| s.to_string()),
469                    avatar_url: additional_fields.get("avatar_url")
470                        .or_else(|| additional_fields.get("picture"))
471                        .and_then(|v| v.as_str())
472                        .map(|s| s.to_string()),
473                    email_verified: additional_fields.get("email_verified")
474                        .and_then(|v| v.as_bool()),
475                    locale: additional_fields.get("locale")
476                        .and_then(|v| v.as_str())
477                        .map(|s| s.to_string()),
478                    additional_fields,
479                }
480            }
481        };
482
483        Ok(user_info)
484    }
485
486    /// Revoke a token if the provider supports it.
487    pub async fn revoke_token(&self, access_token: &str) -> Result<()> {
488        let config = self.config();
489        
490        let revocation_url = config.revocation_url.ok_or_else(|| {
491            AuthError::auth_method(
492                self.name(),
493                "Provider does not support token revocation".to_string(),
494            )
495        })?;
496
497        let client = reqwest::Client::new();
498        let response = client
499            .post(&revocation_url)
500            .form(&[("token", access_token)])
501            .send()
502            .await?;
503
504        if !response.status().is_success() {
505            let error_text = response.text().await.unwrap_or_default();
506            return Err(AuthError::auth_method(
507                self.name(),
508                format!("Token revocation failed: {error_text}"),
509            ));
510        }
511
512        Ok(())
513    }
514}
515
516/// Generate a random state parameter for OAuth flows.
517pub fn generate_state() -> String {
518    use rand::Rng;
519    let mut rng = rand::thread_rng();
520    (0..32)
521        .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
522        .collect()
523}
524
525/// Generate PKCE code verifier and challenge.
526pub fn generate_pkce() -> (String, String) {
527    use rand::Rng;
528    use ring::digest;
529    
530    // Generate code verifier (43-128 characters)
531    let mut rng = rand::thread_rng();
532    let code_verifier: String = (0..128)
533        .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
534        .collect();
535
536    // Generate code challenge (SHA256 hash of verifier, base64url encoded)
537    let digest = digest::digest(&digest::SHA256, code_verifier.as_bytes());        let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest.as_ref());
538
539    (code_verifier, code_challenge)
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// The GitHub preset exposes the expected endpoints and PKCE flag.
    #[test]
    fn test_provider_config() {
        let preset = OAuthProvider::GitHub.config();

        assert_eq!(
            preset.authorization_url,
            "https://github.com/login/oauth/authorize"
        );
        assert_eq!(
            preset.token_url,
            "https://github.com/login/oauth/access_token"
        );
        assert!(preset.supports_pkce);
    }

    /// The authorization URL carries the client, redirect, state and PKCE
    /// parameters, with the redirect URI percent-encoded.
    #[test]
    fn test_authorization_url() {
        let built = OAuthProvider::GitHub
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                None,
                Some("challenge123"),
            )
            .unwrap();

        let expected_fragments = [
            "client_id=client123",
            "redirect_uri=https%3A%2F%2Fexample.com%2Fcallback",
            "state=state123",
            "code_challenge=challenge123",
        ];
        for fragment in expected_fragments.iter() {
            assert!(built.contains(*fragment));
        }
    }

    /// State values are 32 characters long and differ between calls.
    #[test]
    fn test_generate_state() {
        let first = generate_state();
        let second = generate_state();

        assert_eq!(first.len(), 32);
        assert_eq!(second.len(), 32);
        assert_ne!(first, second);
    }

    /// PKCE verifiers are 128 characters long; verifiers and challenges
    /// differ between calls.
    #[test]
    fn test_generate_pkce() {
        let (first_verifier, first_challenge) = generate_pkce();
        let (second_verifier, second_challenge) = generate_pkce();

        assert_eq!(first_verifier.len(), 128);
        assert_eq!(second_verifier.len(), 128);
        assert_ne!(first_verifier, second_verifier);
        assert_ne!(first_challenge, second_challenge);
    }
}