1use base64::Engine;
4use crate::errors::{AuthError, Result};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use url::Url;
8
/// An OAuth 2.0 identity provider.
///
/// The named variants carry built-in endpoint presets (see [`OAuthProvider::config`]);
/// [`OAuthProvider::Custom`] carries an explicit configuration for any other provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OAuthProvider {
    /// GitHub OAuth apps (`github.com/login/oauth`).
    GitHub,

    /// Google accounts (`accounts.google.com`).
    Google,

    /// Microsoft identity platform, `common` tenant endpoints.
    Microsoft,

    /// Discord.
    Discord,

    /// Twitter (API v2 OAuth 2.0 endpoints).
    Twitter,

    /// Facebook (Graph API v18.0 dialog/token endpoints).
    Facebook,

    /// LinkedIn.
    LinkedIn,

    /// GitLab.com (not self-hosted instances; use `Custom` for those).
    GitLab,

    /// A provider not covered by the built-in presets.
    Custom {
        /// Identifier returned by [`OAuthProvider::name`] and used in error reports.
        name: String,
        /// Endpoint and capability settings returned by [`OAuthProvider::config`].
        config: OAuthProviderConfig,
    },
}
42
/// Endpoints and capability flags describing how to talk to one OAuth 2.0 provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthProviderConfig {
    /// Authorization endpoint; `build_authorization_url` parses it and appends query parameters.
    pub authorization_url: String,

    /// Token endpoint, POSTed to (form-encoded) for code exchange and token refresh.
    pub token_url: String,

    /// User-info endpoint fetched with the access token as a bearer credential;
    /// `None` makes `get_user_info` return an error.
    pub userinfo_url: Option<String>,

    /// Token revocation endpoint; `None` makes `revoke_token` return an error.
    pub revocation_url: Option<String>,

    /// Scopes requested when the caller passes `None` for scopes to `build_authorization_url`.
    pub default_scopes: Vec<String>,

    /// When `true`, a caller-supplied PKCE `code_challenge` is added to the authorization URL;
    /// when `false` it is dropped.
    pub supports_pkce: bool,

    /// When `false`, `refresh_token` fails without contacting the provider.
    pub supports_refresh: bool,

    /// Extra query parameters added to every authorization URL.
    pub additional_params: HashMap<String, String>,
}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct OAuthTokenResponse {
74 pub access_token: String,
76
77 pub token_type: String,
79
80 pub expires_in: Option<u64>,
82
83 pub refresh_token: Option<String>,
85
86 pub scope: Option<String>,
88
89 #[serde(flatten)]
91 pub additional_fields: HashMap<String, serde_json::Value>,
92}
93
/// Provider-independent view of the authenticated user, produced by
/// [`OAuthProvider::get_user_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUserInfo {
    /// Provider-assigned user identifier, always stored as a string
    /// (numeric ids such as GitHub's are converted with `to_string`).
    pub id: String,

    /// Login/handle, when the provider exposes one (e.g. GitHub `login`; `None` for Google).
    pub username: Option<String>,

    /// Email address, when present in the provider's response.
    pub email: Option<String>,

    /// Display name, when present.
    pub name: Option<String>,

    /// Profile picture URL, when present (`avatar_url` or `picture` depending on provider).
    pub avatar_url: Option<String>,

    /// Email verification flag as reported by the provider; `None` when not reported.
    pub email_verified: Option<bool>,

    /// Locale string, when present.
    pub locale: Option<String>,

    /// Raw response fields retained next to the normalized ones; which fields appear
    /// depends on the provider-specific parsing in `OAuthProvider::get_user_info`.
    pub additional_fields: HashMap<String, serde_json::Value>,
}
121
122impl OAuthProvider {
123 pub fn config(&self) -> OAuthProviderConfig {
125 match self {
126 Self::GitHub => OAuthProviderConfig {
127 authorization_url: "https://github.com/login/oauth/authorize".to_string(),
128 token_url: "https://github.com/login/oauth/access_token".to_string(),
129 userinfo_url: Some("https://api.github.com/user".to_string()),
130 revocation_url: None,
131 default_scopes: vec!["user:email".to_string()],
132 supports_pkce: true,
133 supports_refresh: false,
134 additional_params: HashMap::new(),
135 },
136
137 Self::Google => OAuthProviderConfig {
138 authorization_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
139 token_url: "https://oauth2.googleapis.com/token".to_string(),
140 userinfo_url: Some("https://www.googleapis.com/oauth2/v2/userinfo".to_string()),
141 revocation_url: Some("https://oauth2.googleapis.com/revoke".to_string()),
142 default_scopes: vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
143 supports_pkce: true,
144 supports_refresh: true,
145 additional_params: HashMap::new(),
146 },
147
148 Self::Microsoft => OAuthProviderConfig {
149 authorization_url: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize".to_string(),
150 token_url: "https://login.microsoftonline.com/common/oauth2/v2.0/token".to_string(),
151 userinfo_url: Some("https://graph.microsoft.com/v1.0/me".to_string()),
152 revocation_url: None,
153 default_scopes: vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
154 supports_pkce: true,
155 supports_refresh: true,
156 additional_params: HashMap::new(),
157 },
158
159 Self::Discord => OAuthProviderConfig {
160 authorization_url: "https://discord.com/api/oauth2/authorize".to_string(),
161 token_url: "https://discord.com/api/oauth2/token".to_string(),
162 userinfo_url: Some("https://discord.com/api/users/@me".to_string()),
163 revocation_url: Some("https://discord.com/api/oauth2/token/revoke".to_string()),
164 default_scopes: vec!["identify".to_string(), "email".to_string()],
165 supports_pkce: false,
166 supports_refresh: true,
167 additional_params: HashMap::new(),
168 },
169
170 Self::Twitter => OAuthProviderConfig {
171 authorization_url: "https://twitter.com/i/oauth2/authorize".to_string(),
172 token_url: "https://api.twitter.com/2/oauth2/token".to_string(),
173 userinfo_url: Some("https://api.twitter.com/2/users/me".to_string()),
174 revocation_url: Some("https://api.twitter.com/2/oauth2/revoke".to_string()),
175 default_scopes: vec!["tweet.read".to_string(), "users.read".to_string()],
176 supports_pkce: true,
177 supports_refresh: true,
178 additional_params: HashMap::new(),
179 },
180
181 Self::Facebook => OAuthProviderConfig {
182 authorization_url: "https://www.facebook.com/v18.0/dialog/oauth".to_string(),
183 token_url: "https://graph.facebook.com/v18.0/oauth/access_token".to_string(),
184 userinfo_url: Some("https://graph.facebook.com/me".to_string()),
185 revocation_url: None,
186 default_scopes: vec!["email".to_string(), "public_profile".to_string()],
187 supports_pkce: false,
188 supports_refresh: false,
189 additional_params: HashMap::new(),
190 },
191
192 Self::LinkedIn => OAuthProviderConfig {
193 authorization_url: "https://www.linkedin.com/oauth/v2/authorization".to_string(),
194 token_url: "https://www.linkedin.com/oauth/v2/accessToken".to_string(),
195 userinfo_url: Some("https://api.linkedin.com/v2/me".to_string()),
196 revocation_url: None,
197 default_scopes: vec!["r_liteprofile".to_string(), "r_emailaddress".to_string()],
198 supports_pkce: false,
199 supports_refresh: true,
200 additional_params: HashMap::new(),
201 },
202
203 Self::GitLab => OAuthProviderConfig {
204 authorization_url: "https://gitlab.com/oauth/authorize".to_string(),
205 token_url: "https://gitlab.com/oauth/token".to_string(),
206 userinfo_url: Some("https://gitlab.com/api/v4/user".to_string()),
207 revocation_url: None,
208 default_scopes: vec!["read_user".to_string()],
209 supports_pkce: true,
210 supports_refresh: true,
211 additional_params: HashMap::new(),
212 },
213
214 Self::Custom { config, .. } => config.clone(),
215 }
216 }
217
218 pub fn name(&self) -> &str {
220 match self {
221 Self::GitHub => "github",
222 Self::Google => "google",
223 Self::Microsoft => "microsoft",
224 Self::Discord => "discord",
225 Self::Twitter => "twitter",
226 Self::Facebook => "facebook",
227 Self::LinkedIn => "linkedin",
228 Self::GitLab => "gitlab",
229 Self::Custom { name, .. } => name,
230 }
231 }
232
233 pub fn custom(name: impl Into<String>, config: OAuthProviderConfig) -> Self {
235 Self::Custom {
236 name: name.into(),
237 config,
238 }
239 }
240
241 pub fn build_authorization_url(
243 &self,
244 client_id: &str,
245 redirect_uri: &str,
246 state: &str,
247 scopes: Option<&[String]>,
248 code_challenge: Option<&str>,
249 ) -> Result<String> {
250 let config = self.config();
251 let mut url = Url::parse(&config.authorization_url)
252 .map_err(|e| AuthError::config(format!("Invalid authorization URL: {e}")))?;
253
254 let scopes = scopes.unwrap_or(&config.default_scopes);
255
256 {
257 let mut query = url.query_pairs_mut();
258 query.append_pair("client_id", client_id);
259 query.append_pair("redirect_uri", redirect_uri);
260 query.append_pair("response_type", "code");
261 query.append_pair("state", state);
262
263 if !scopes.is_empty() {
264 query.append_pair("scope", &scopes.join(" "));
265 }
266
267 if config.supports_pkce {
269 if let Some(challenge) = code_challenge {
270 query.append_pair("code_challenge", challenge);
271 query.append_pair("code_challenge_method", "S256");
272 }
273 }
274
275 for (key, value) in &config.additional_params {
277 query.append_pair(key, value);
278 }
279 }
280
281 Ok(url.to_string())
282 }
283
284 pub async fn exchange_code(
286 &self,
287 client_id: &str,
288 client_secret: &str,
289 authorization_code: &str,
290 redirect_uri: &str,
291 code_verifier: Option<&str>,
292 ) -> Result<OAuthTokenResponse> {
293 let config = self.config();
294 let client = reqwest::Client::new();
295
296 let mut params = vec![
297 ("grant_type", "authorization_code"),
298 ("client_id", client_id),
299 ("client_secret", client_secret),
300 ("code", authorization_code),
301 ("redirect_uri", redirect_uri),
302 ];
303
304 if let Some(verifier) = code_verifier {
306 params.push(("code_verifier", verifier));
307 }
308
309 let response = client
310 .post(&config.token_url)
311 .form(¶ms)
312 .send()
313 .await?;
314
315 if !response.status().is_success() {
316 let error_text = response.text().await.unwrap_or_default();
317 return Err(AuthError::auth_method(
318 self.name(),
319 format!("Token exchange failed: {error_text}"),
320 ));
321 }
322
323 let token_response: OAuthTokenResponse = response.json().await?;
324 Ok(token_response)
325 }
326
327 pub async fn refresh_token(
329 &self,
330 client_id: &str,
331 client_secret: &str,
332 refresh_token: &str,
333 ) -> Result<OAuthTokenResponse> {
334 let config = self.config();
335
336 if !config.supports_refresh {
337 return Err(AuthError::auth_method(
338 self.name(),
339 "Provider does not support token refresh".to_string(),
340 ));
341 }
342
343 let client = reqwest::Client::new();
344
345 let params = vec![
346 ("grant_type", "refresh_token"),
347 ("client_id", client_id),
348 ("client_secret", client_secret),
349 ("refresh_token", refresh_token),
350 ];
351
352 let response = client
353 .post(&config.token_url)
354 .form(¶ms)
355 .send()
356 .await?;
357
358 if !response.status().is_success() {
359 let error_text = response.text().await.unwrap_or_default();
360 return Err(AuthError::auth_method(
361 self.name(),
362 format!("Token refresh failed: {error_text}"),
363 ));
364 }
365
366 let token_response: OAuthTokenResponse = response.json().await?;
367 Ok(token_response)
368 }
369
370 pub async fn get_user_info(&self, access_token: &str) -> Result<OAuthUserInfo> {
372 let config = self.config();
373
374 let userinfo_url = config.userinfo_url.ok_or_else(|| {
375 AuthError::auth_method(
376 self.name(),
377 "Provider does not support user info endpoint".to_string(),
378 )
379 })?;
380
381 let client = reqwest::Client::new();
382 let response = client
383 .get(&userinfo_url)
384 .bearer_auth(access_token)
385 .send()
386 .await?;
387
388 if !response.status().is_success() {
389 let error_text = response.text().await.unwrap_or_default();
390 return Err(AuthError::auth_method(
391 self.name(),
392 format!("User info request failed: {error_text}"),
393 ));
394 }
395
396 let user_data: serde_json::Value = response.json().await?;
397
398 let user_info = self.parse_user_info(user_data)?;
400 Ok(user_info)
401 }
402
403 fn parse_user_info(&self, data: serde_json::Value) -> Result<OAuthUserInfo> {
405 let mut additional_fields = HashMap::new();
406
407 let user_info = match self {
408 Self::GitHub => {
409 let id = data["id"].as_u64()
410 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
411 .to_string();
412
413 OAuthUserInfo {
414 id,
415 username: data["login"].as_str().map(|s| s.to_string()),
416 email: data["email"].as_str().map(|s| s.to_string()),
417 name: data["name"].as_str().map(|s| s.to_string()),
418 avatar_url: data["avatar_url"].as_str().map(|s| s.to_string()),
419 email_verified: None, locale: None,
421 additional_fields,
422 }
423 }
424
425 Self::Google => {
426 let id = data["id"].as_str()
427 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
428 .to_string();
429
430 OAuthUserInfo {
431 id,
432 username: None, email: data["email"].as_str().map(|s| s.to_string()),
434 name: data["name"].as_str().map(|s| s.to_string()),
435 avatar_url: data["picture"].as_str().map(|s| s.to_string()),
436 email_verified: data["verified_email"].as_bool(),
437 locale: data["locale"].as_str().map(|s| s.to_string()),
438 additional_fields,
439 }
440 }
441
442 _ => {
444 let id = data["id"].as_str()
446 .or_else(|| data["sub"].as_str())
447 .or_else(|| data["user_id"].as_str())
448 .ok_or_else(|| AuthError::auth_method(self.name(), "Missing user ID"))?
449 .to_string();
450
451 if let serde_json::Value::Object(map) = data {
453 additional_fields = map.into_iter().collect();
454 }
455
456 OAuthUserInfo {
457 id,
458 username: additional_fields.get("username")
459 .or_else(|| additional_fields.get("login"))
460 .and_then(|v| v.as_str())
461 .map(|s| s.to_string()),
462 email: additional_fields.get("email")
463 .and_then(|v| v.as_str())
464 .map(|s| s.to_string()),
465 name: additional_fields.get("name")
466 .or_else(|| additional_fields.get("display_name"))
467 .and_then(|v| v.as_str())
468 .map(|s| s.to_string()),
469 avatar_url: additional_fields.get("avatar_url")
470 .or_else(|| additional_fields.get("picture"))
471 .and_then(|v| v.as_str())
472 .map(|s| s.to_string()),
473 email_verified: additional_fields.get("email_verified")
474 .and_then(|v| v.as_bool()),
475 locale: additional_fields.get("locale")
476 .and_then(|v| v.as_str())
477 .map(|s| s.to_string()),
478 additional_fields,
479 }
480 }
481 };
482
483 Ok(user_info)
484 }
485
486 pub async fn revoke_token(&self, access_token: &str) -> Result<()> {
488 let config = self.config();
489
490 let revocation_url = config.revocation_url.ok_or_else(|| {
491 AuthError::auth_method(
492 self.name(),
493 "Provider does not support token revocation".to_string(),
494 )
495 })?;
496
497 let client = reqwest::Client::new();
498 let response = client
499 .post(&revocation_url)
500 .form(&[("token", access_token)])
501 .send()
502 .await?;
503
504 if !response.status().is_success() {
505 let error_text = response.text().await.unwrap_or_default();
506 return Err(AuthError::auth_method(
507 self.name(),
508 format!("Token revocation failed: {error_text}"),
509 ));
510 }
511
512 Ok(())
513 }
514}
515
516pub fn generate_state() -> String {
518 use rand::Rng;
519 let mut rng = rand::thread_rng();
520 (0..32)
521 .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
522 .collect()
523}
524
525pub fn generate_pkce() -> (String, String) {
527 use rand::Rng;
528 use ring::digest;
529
530 let mut rng = rand::thread_rng();
532 let code_verifier: String = (0..128)
533 .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
534 .collect();
535
536 let digest = digest::digest(&digest::SHA256, code_verifier.as_bytes()); let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest.as_ref());
538
539 (code_verifier, code_challenge)
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// The GitHub preset points at GitHub's OAuth endpoints and enables PKCE.
    #[test]
    fn test_provider_config() {
        let config = OAuthProvider::GitHub.config();

        assert_eq!(
            config.authorization_url,
            "https://github.com/login/oauth/authorize"
        );
        assert_eq!(
            config.token_url,
            "https://github.com/login/oauth/access_token"
        );
        assert!(config.supports_pkce);
    }

    /// The authorization URL carries URL-encoded client, redirect, state and PKCE parameters.
    #[test]
    fn test_authorization_url() {
        let url = OAuthProvider::GitHub
            .build_authorization_url(
                "client123",
                "https://example.com/callback",
                "state123",
                None,
                Some("challenge123"),
            )
            .unwrap();

        let expected_fragments = [
            "client_id=client123",
            "redirect_uri=https%3A%2F%2Fexample.com%2Fcallback",
            "state=state123",
            "code_challenge=challenge123",
        ];
        for fragment in expected_fragments {
            assert!(url.contains(fragment), "`{fragment}` missing from {url}");
        }
    }

    /// Each state value is 32 characters and two calls differ.
    #[test]
    fn test_generate_state() {
        let (first, second) = (generate_state(), generate_state());

        assert_eq!(first.len(), 32);
        assert_eq!(second.len(), 32);
        assert_ne!(first, second);
    }

    /// Each verifier is 128 characters and two calls yield distinct pairs.
    #[test]
    fn test_generate_pkce() {
        let (verifier_a, challenge_a) = generate_pkce();
        let (verifier_b, challenge_b) = generate_pkce();

        assert_eq!(verifier_a.len(), 128);
        assert_eq!(verifier_b.len(), 128);
        assert_ne!(verifier_a, verifier_b);
        assert_ne!(challenge_a, challenge_b);
    }
}