1use crate::AuthBackend;
4use async_trait::async_trait;
5use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
6use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
7use rusmes_proto::Username;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use tokio::sync::RwLock;
13
/// Identity provider whose tokens the backend accepts.
///
/// Every variant carries the OAuth2 client credentials registered with that
/// provider.
// NOTE(review): `Debug` is derived, so `{:?}` output includes `client_secret`;
// avoid logging values of this type verbatim.
#[derive(Debug, Clone)]
pub enum OidcProvider {
    /// Google as the identity provider.
    Google {
        /// OAuth2 client identifier.
        client_id: String,
        /// OAuth2 client secret.
        client_secret: String,
    },
    /// Microsoft as the identity provider, scoped to one tenant.
    Microsoft {
        /// Tenant identifier interpolated into the Microsoft endpoint URLs.
        tenant_id: String,
        /// OAuth2 client identifier.
        client_id: String,
        /// OAuth2 client secret.
        client_secret: String,
    },
    /// Any other provider, described by explicit URLs.
    Generic {
        /// Base issuer URL.
        issuer_url: String,
        /// OAuth2 client identifier.
        client_id: String,
        /// OAuth2 client secret.
        client_secret: String,
        /// URL the JSON Web Key Set is downloaded from.
        jwks_url: String,
    },
}
36
/// Claims deserialized from the payload of a JWT.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    /// Subject identifier.
    sub: String,
    /// Email address, when the token carries one.
    email: Option<String>,
    /// Expiration time (presumably seconds since the Unix epoch — confirm).
    exp: u64,
    /// Issued-at time (presumably seconds since the Unix epoch — confirm).
    iat: u64,
    /// Issuer of the token.
    iss: String,
    /// Audience of the token.
    // NOTE(review): declared as a single string; a token whose `aud` is a JSON
    // array would presumably fail to deserialize — confirm against the
    // providers in use.
    aud: String,
}
47
/// JSON body returned by the token introspection endpoint.
///
/// Only `active` is required; the other fields default to `None` when absent.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct IntrospectionResponse {
    /// Whether the provider considers the token active.
    active: bool,
    /// Username associated with the token, if reported.
    #[serde(default)]
    username: Option<String>,
    /// Email associated with the token, if reported.
    #[serde(default)]
    email: Option<String>,
    /// Expiration time of the token, if reported.
    #[serde(default)]
    exp: Option<u64>,
}
60
/// A JSON Web Key Set as downloaded from the provider.
#[derive(Debug, Clone, Deserialize)]
struct Jwks {
    /// The keys contained in the set.
    keys: Vec<Jwk>,
}
66
/// A single JSON Web Key; only the fields needed for RSA keys are modelled.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct Jwk {
    /// Key identifier, matched against the `kid` of a JWT header.
    kid: String,
    /// Key type.
    kty: String,
    /// Intended use of the key (the JSON field is named `use`).
    #[serde(rename = "use")]
    key_use: Option<String>,
    /// Algorithm the key is meant for, if declared.
    alg: Option<String>,
    /// RSA modulus, as the encoded string found in the JWK.
    n: Option<String>,
    /// RSA public exponent, as the encoded string found in the JWK.
    e: Option<String>,
}
79
/// Record of a successful authentication kept in the in-memory token cache.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct TokenCacheEntry {
    /// Username the token resolved to.
    username: String,
    /// Instant after which the entry is considered expired.
    expires_at: SystemTime,
}
87
/// Configuration for [`OAuth2Backend`].
#[derive(Debug, Clone)]
pub struct OAuth2Config {
    /// Identity provider and its client credentials.
    pub provider: OidcProvider,
    /// Token introspection endpoint; introspection fails when this is `None`.
    pub introspection_endpoint: Option<String>,
    /// How long a downloaded JWKS is reused, in seconds.
    pub jwks_cache_ttl: u64,
    /// Whether the refresh-token flow may be used.
    pub enable_refresh_tokens: bool,
    /// Signature algorithms accepted when validating a JWT.
    pub allowed_algorithms: Vec<Algorithm>,
}
102
103impl Default for OAuth2Config {
104 fn default() -> Self {
105 Self {
106 provider: OidcProvider::Generic {
107 issuer_url: "https://example.com".to_string(),
108 client_id: "client-id".to_string(),
109 client_secret: "client-secret".to_string(),
110 jwks_url: "https://example.com/.well-known/jwks.json".to_string(),
111 },
112 introspection_endpoint: None,
113 jwks_cache_ttl: 3600,
114 enable_refresh_tokens: true,
115 allowed_algorithms: vec![Algorithm::RS256],
116 }
117 }
118}
119
/// Authentication backend that validates OAuth2 bearer tokens.
pub struct OAuth2Backend {
    /// Provider settings and validation options.
    config: OAuth2Config,
    /// Successful authentications, keyed by username.
    token_cache: Arc<RwLock<HashMap<String, TokenCacheEntry>>>,
    /// Last downloaded JWKS together with the time it was stored.
    jwks_cache: Arc<RwLock<Option<(Jwks, SystemTime)>>>,
    /// HTTP client used for JWKS, introspection and token requests.
    client: reqwest::Client,
}
127
128impl OAuth2Backend {
129 pub fn new(config: OAuth2Config) -> Self {
131 Self {
132 config,
133 token_cache: Arc::new(RwLock::new(HashMap::new())),
134 jwks_cache: Arc::new(RwLock::new(None)),
135 client: reqwest::Client::new(),
136 }
137 }
138
139 pub fn parse_xoauth2_response(response: &str) -> anyhow::Result<(String, String)> {
143 let decoded = BASE64
145 .decode(response.as_bytes())
146 .map_err(|e| anyhow::anyhow!("Failed to decode XOAUTH2 response: {}", e))?;
147
148 let decoded_str = String::from_utf8(decoded)
149 .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in XOAUTH2 response: {}", e))?;
150
151 let parts: Vec<&str> = decoded_str.split('\x01').collect();
153
154 let mut username = None;
156 let mut token = None;
157
158 for part in &parts {
159 if part.starts_with("user=") {
160 username = part.strip_prefix("user=").map(|s| s.to_string());
161 } else if part.starts_with("auth=Bearer ") {
162 token = part.strip_prefix("auth=Bearer ").map(|s| s.to_string());
163 }
164 }
165
166 let username = username.ok_or_else(|| anyhow::anyhow!("Missing username in XOAUTH2"))?;
167 let token = token.ok_or_else(|| anyhow::anyhow!("Missing token in XOAUTH2"))?;
168
169 Ok((username, token))
170 }
171
172 #[allow(dead_code)]
176 pub fn encode_xoauth2_response(username: &str, token: &str) -> String {
177 let response = format!("user={}\x01auth=Bearer {}\x01\x01", username, token);
178 BASE64.encode(response.as_bytes())
179 }
180
181 pub async fn cleanup_expired_tokens(&self) {
183 let mut cache = self.token_cache.write().await;
184 let now = SystemTime::now();
185 cache.retain(|_, entry| entry.expires_at > now);
186 }
187
188 #[allow(dead_code)]
190 pub async fn token_cache_size(&self) -> usize {
191 let cache = self.token_cache.read().await;
192 cache.len()
193 }
194
195 #[allow(dead_code)]
197 pub async fn invalidate_token(&self, username: &str) {
198 let mut cache = self.token_cache.write().await;
199 cache.remove(username);
200 }
201
202 #[allow(dead_code)]
204 pub async fn clear_jwks_cache(&self) {
205 let mut cache = self.jwks_cache.write().await;
206 *cache = None;
207 }
208
209 async fn get_jwks(&self) -> anyhow::Result<Jwks> {
211 {
213 let cache = self.jwks_cache.read().await;
214 if let Some((jwks, cached_at)) = &*cache {
215 if cached_at.elapsed().unwrap_or(Duration::MAX).as_secs()
216 < self.config.jwks_cache_ttl
217 {
218 return Ok(jwks.clone());
219 }
220 }
221 }
222
223 let jwks_url = match &self.config.provider {
225 OidcProvider::Google { .. } => "https://www.googleapis.com/oauth2/v3/certs",
226 OidcProvider::Microsoft { tenant_id, .. } => &format!(
227 "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
228 tenant_id
229 ),
230 OidcProvider::Generic { jwks_url, .. } => jwks_url.as_str(),
231 };
232
233 let jwks: Jwks = self.client.get(jwks_url).send().await?.json().await?;
234
235 {
237 let mut cache = self.jwks_cache.write().await;
238 *cache = Some((jwks.clone(), SystemTime::now()));
239 }
240
241 Ok(jwks)
242 }
243
244 async fn validate_jwt(&self, token: &str) -> anyhow::Result<Claims> {
246 let header = decode_header(token)?;
248 let kid = header
249 .kid
250 .ok_or_else(|| anyhow::anyhow!("No kid in JWT header"))?;
251
252 let jwks = self.get_jwks().await?;
254
255 let jwk = jwks
257 .keys
258 .iter()
259 .find(|k| k.kid == kid)
260 .ok_or_else(|| anyhow::anyhow!("No matching key found in JWKS"))?;
261
262 let n = jwk
264 .n
265 .as_ref()
266 .ok_or_else(|| anyhow::anyhow!("Missing n in JWK"))?;
267 let e = jwk
268 .e
269 .as_ref()
270 .ok_or_else(|| anyhow::anyhow!("Missing e in JWK"))?;
271
272 let n_bytes = BASE64.decode(n)?;
273 let e_bytes = BASE64.decode(e)?;
274
275 let decoding_key =
277 DecodingKey::from_rsa_components(&BASE64.encode(&n_bytes), &BASE64.encode(&e_bytes))?;
278
279 let mut validation = Validation::new(Algorithm::RS256);
281 validation.algorithms = self.config.allowed_algorithms.clone();
282
283 let expected_aud = match &self.config.provider {
285 OidcProvider::Google { client_id, .. } => client_id.clone(),
286 OidcProvider::Microsoft { client_id, .. } => client_id.clone(),
287 OidcProvider::Generic { client_id, .. } => client_id.clone(),
288 };
289 validation.set_audience(&[&expected_aud]);
290
291 let token_data = decode::<Claims>(token, &decoding_key, &validation)?;
292
293 Ok(token_data.claims)
294 }
295
296 async fn introspect_token(&self, token: &str) -> anyhow::Result<IntrospectionResponse> {
298 let endpoint = self
299 .config
300 .introspection_endpoint
301 .as_ref()
302 .ok_or_else(|| anyhow::anyhow!("Token introspection endpoint not configured"))?;
303
304 let (client_id, client_secret) = match &self.config.provider {
305 OidcProvider::Google {
306 client_id,
307 client_secret,
308 } => (client_id, client_secret),
309 OidcProvider::Microsoft {
310 client_id,
311 client_secret,
312 ..
313 } => (client_id, client_secret),
314 OidcProvider::Generic {
315 client_id,
316 client_secret,
317 ..
318 } => (client_id, client_secret),
319 };
320
321 let mut params = HashMap::new();
322 params.insert("token", token);
323 params.insert("client_id", client_id);
324 params.insert("client_secret", client_secret);
325
326 let response = self
327 .client
328 .post(endpoint)
329 .form(¶ms)
330 .send()
331 .await?
332 .json::<IntrospectionResponse>()
333 .await?;
334
335 Ok(response)
336 }
337
338 async fn xoauth2_authenticate(&self, token: &str) -> anyhow::Result<String> {
340 if let Ok(claims) = self.validate_jwt(token).await {
342 return Ok(claims.email.or(Some(claims.sub)).unwrap_or_default());
343 }
344
345 let introspection = self.introspect_token(token).await?;
347
348 if !introspection.active {
349 return Err(anyhow::anyhow!("Token is not active"));
350 }
351
352 introspection
353 .email
354 .or(introspection.username)
355 .ok_or_else(|| anyhow::anyhow!("No username in token"))
356 }
357
358 #[allow(dead_code)]
360 async fn refresh_token(&self, refresh_token: &str) -> anyhow::Result<String> {
361 if !self.config.enable_refresh_tokens {
362 return Err(anyhow::anyhow!("Refresh tokens not enabled"));
363 }
364
365 let token_endpoint = match &self.config.provider {
366 OidcProvider::Google { .. } => "https://oauth2.googleapis.com/token",
367 OidcProvider::Microsoft { tenant_id, .. } => &format!(
368 "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
369 tenant_id
370 ),
371 OidcProvider::Generic { issuer_url, .. } => &format!("{}/token", issuer_url),
372 };
373
374 let (client_id, client_secret) = match &self.config.provider {
375 OidcProvider::Google {
376 client_id,
377 client_secret,
378 } => (client_id, client_secret),
379 OidcProvider::Microsoft {
380 client_id,
381 client_secret,
382 ..
383 } => (client_id, client_secret),
384 OidcProvider::Generic {
385 client_id,
386 client_secret,
387 ..
388 } => (client_id, client_secret),
389 };
390
391 let mut params = HashMap::new();
392 params.insert("grant_type", "refresh_token");
393 params.insert("refresh_token", refresh_token);
394 params.insert("client_id", client_id);
395 params.insert("client_secret", client_secret);
396
397 #[derive(Deserialize)]
398 struct TokenResponse {
399 access_token: String,
400 }
401
402 let response = self
403 .client
404 .post(token_endpoint)
405 .form(¶ms)
406 .send()
407 .await?
408 .json::<TokenResponse>()
409 .await?;
410
411 Ok(response.access_token)
412 }
413}
414
415#[async_trait]
416impl AuthBackend for OAuth2Backend {
417 async fn authenticate(&self, username: &Username, password: &str) -> anyhow::Result<bool> {
418 let token = password;
420
421 {
423 let cache = self.token_cache.read().await;
424 if let Some(entry) = cache.get(&username.to_string()) {
425 if SystemTime::now() < entry.expires_at {
426 return Ok(true);
427 }
428 }
429 }
430
431 match self.xoauth2_authenticate(token).await {
433 Ok(token_username) => {
434 if token_username == username.to_string() {
435 let mut cache = self.token_cache.write().await;
437 cache.insert(
438 username.to_string(),
439 TokenCacheEntry {
440 username: token_username,
441 expires_at: SystemTime::now() + Duration::from_secs(300),
442 },
443 );
444 Ok(true)
445 } else {
446 Ok(false)
447 }
448 }
449 Err(_) => Ok(false),
450 }
451 }
452
453 async fn verify_identity(&self, username: &Username) -> anyhow::Result<bool> {
454 let cache = self.token_cache.read().await;
456 Ok(cache.contains_key(&username.to_string()))
457 }
458
459 async fn list_users(&self) -> anyhow::Result<Vec<Username>> {
460 let cache = self.token_cache.read().await;
462 Ok(cache
463 .keys()
464 .filter_map(|k| Username::new(k.clone()).ok())
465 .collect())
466 }
467
468 async fn create_user(&self, _username: &Username, _password: &str) -> anyhow::Result<()> {
469 Err(anyhow::anyhow!(
470 "OAuth2 backend does not support user creation (external provider)"
471 ))
472 }
473
474 async fn delete_user(&self, _username: &Username) -> anyhow::Result<()> {
475 Err(anyhow::anyhow!(
476 "OAuth2 backend does not support user deletion (external provider)"
477 ))
478 }
479
480 async fn change_password(
481 &self,
482 _username: &Username,
483 _new_password: &str,
484 ) -> anyhow::Result<()> {
485 Err(anyhow::anyhow!(
486 "OAuth2 backend does not support password changes (external provider)"
487 ))
488 }
489}
490
491#[cfg(test)]
492mod tests {
493 use super::*;
494
495 #[test]
500 fn test_oauth2_config_default() {
501 let config = OAuth2Config::default();
502 assert_eq!(config.jwks_cache_ttl, 3600);
503 assert!(config.enable_refresh_tokens);
504 assert_eq!(config.allowed_algorithms.len(), 1);
505 }
506
507 #[test]
508 fn test_oauth2_config_google() {
509 let config = OAuth2Config {
510 provider: OidcProvider::Google {
511 client_id: "test-client-id".to_string(),
512 client_secret: "test-secret".to_string(),
513 },
514 ..Default::default()
515 };
516 assert!(matches!(config.provider, OidcProvider::Google { .. }));
517 }
518
519 #[test]
520 fn test_oauth2_config_microsoft() {
521 let config = OAuth2Config {
522 provider: OidcProvider::Microsoft {
523 tenant_id: "test-tenant".to_string(),
524 client_id: "test-client".to_string(),
525 client_secret: "test-secret".to_string(),
526 },
527 ..Default::default()
528 };
529 assert!(matches!(config.provider, OidcProvider::Microsoft { .. }));
530 }
531
532 #[test]
533 fn test_oauth2_config_generic() {
534 let config = OAuth2Config {
535 provider: OidcProvider::Generic {
536 issuer_url: "https://oidc.example.com".to_string(),
537 client_id: "client".to_string(),
538 client_secret: "secret".to_string(),
539 jwks_url: "https://oidc.example.com/jwks".to_string(),
540 },
541 ..Default::default()
542 };
543 assert!(matches!(config.provider, OidcProvider::Generic { .. }));
544 }
545
546 #[test]
547 fn test_allowed_algorithms() {
548 let config = OAuth2Config {
549 allowed_algorithms: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512],
550 ..Default::default()
551 };
552 assert_eq!(config.allowed_algorithms.len(), 3);
553 }
554
555 #[test]
556 fn test_introspection_endpoint_optional() {
557 let config = OAuth2Config::default();
558 assert!(config.introspection_endpoint.is_none());
559
560 let config_with_introspection = OAuth2Config {
561 introspection_endpoint: Some("https://example.com/introspect".to_string()),
562 ..Default::default()
563 };
564 assert!(config_with_introspection.introspection_endpoint.is_some());
565 }
566
567 #[test]
568 fn test_refresh_tokens_enabled() {
569 let config = OAuth2Config {
570 enable_refresh_tokens: true,
571 ..Default::default()
572 };
573 assert!(config.enable_refresh_tokens);
574
575 let config_disabled = OAuth2Config {
576 enable_refresh_tokens: false,
577 ..Default::default()
578 };
579 assert!(!config_disabled.enable_refresh_tokens);
580 }
581
582 #[test]
583 fn test_jwks_cache_ttl() {
584 let config = OAuth2Config {
585 jwks_cache_ttl: 7200,
586 ..Default::default()
587 };
588 assert_eq!(config.jwks_cache_ttl, 7200);
589 }
590
591 #[test]
592 fn test_config_clone() {
593 let config = OAuth2Config::default();
594 let cloned = config.clone();
595 assert_eq!(config.jwks_cache_ttl, cloned.jwks_cache_ttl);
596 }
597
598 #[tokio::test]
603 async fn test_oauth2_backend_creation() {
604 let config = OAuth2Config::default();
605 let backend = OAuth2Backend::new(config);
606 let cache = backend.token_cache.read().await;
607 assert_eq!(cache.len(), 0);
608 }
609
610 #[tokio::test]
611 async fn test_token_cache_empty_on_creation() {
612 let backend = OAuth2Backend::new(OAuth2Config::default());
613 let cache = backend.token_cache.read().await;
614 assert!(cache.is_empty());
615 }
616
617 #[tokio::test]
618 async fn test_jwks_cache_empty_on_creation() {
619 let backend = OAuth2Backend::new(OAuth2Config::default());
620 let cache = backend.jwks_cache.read().await;
621 assert!(cache.is_none());
622 }
623
624 #[tokio::test]
629 async fn test_create_user_not_supported() {
630 let backend = OAuth2Backend::new(OAuth2Config::default());
631 let username = Username::new("user@example.com".to_string()).unwrap();
632 let result = backend.create_user(&username, "token").await;
633 assert!(result.is_err());
634 assert!(result
635 .unwrap_err()
636 .to_string()
637 .contains("external provider"));
638 }
639
640 #[tokio::test]
641 async fn test_delete_user_not_supported() {
642 let backend = OAuth2Backend::new(OAuth2Config::default());
643 let username = Username::new("user@example.com".to_string()).unwrap();
644 let result = backend.delete_user(&username).await;
645 assert!(result.is_err());
646 assert!(result
647 .unwrap_err()
648 .to_string()
649 .contains("external provider"));
650 }
651
652 #[tokio::test]
653 async fn test_change_password_not_supported() {
654 let backend = OAuth2Backend::new(OAuth2Config::default());
655 let username = Username::new("user@example.com".to_string()).unwrap();
656 let result = backend.change_password(&username, "newtoken").await;
657 assert!(result.is_err());
658 assert!(result
659 .unwrap_err()
660 .to_string()
661 .contains("external provider"));
662 }
663
664 #[tokio::test]
665 async fn test_list_users_empty() {
666 let backend = OAuth2Backend::new(OAuth2Config::default());
667 let users = backend.list_users().await.unwrap();
668 assert_eq!(users.len(), 0);
669 }
670
671 #[tokio::test]
672 async fn test_verify_identity_not_cached() {
673 let backend = OAuth2Backend::new(OAuth2Config::default());
674 let username = Username::new("user@example.com".to_string()).unwrap();
675 let verified = backend.verify_identity(&username).await.unwrap();
676 assert!(!verified);
677 }
678
679 #[tokio::test]
680 async fn test_verify_identity_cached() {
681 let backend = OAuth2Backend::new(OAuth2Config::default());
682 let username = Username::new("cached@example.com".to_string()).unwrap();
683
684 {
685 let mut cache = backend.token_cache.write().await;
686 cache.insert(
687 username.to_string(),
688 TokenCacheEntry {
689 username: username.to_string(),
690 expires_at: SystemTime::now() + Duration::from_secs(300),
691 },
692 );
693 }
694
695 let verified = backend.verify_identity(&username).await.unwrap();
696 assert!(verified);
697 }
698
699 #[tokio::test]
704 async fn test_token_cache_insertion() {
705 let backend = OAuth2Backend::new(OAuth2Config::default());
706
707 {
708 let mut cache = backend.token_cache.write().await;
709 cache.insert(
710 "user@example.com".to_string(),
711 TokenCacheEntry {
712 username: "user@example.com".to_string(),
713 expires_at: SystemTime::now() + Duration::from_secs(300),
714 },
715 );
716 }
717
718 let cache = backend.token_cache.read().await;
719 assert_eq!(cache.len(), 1);
720 assert!(cache.contains_key("user@example.com"));
721 }
722
723 #[tokio::test]
724 async fn test_token_cache_expiration() {
725 let backend = OAuth2Backend::new(OAuth2Config::default());
726
727 {
728 let mut cache = backend.token_cache.write().await;
729 cache.insert(
730 "expired@example.com".to_string(),
731 TokenCacheEntry {
732 username: "expired@example.com".to_string(),
733 expires_at: SystemTime::now() - Duration::from_secs(1),
734 },
735 );
736 }
737
738 let cache = backend.token_cache.read().await;
740 let entry = cache.get("expired@example.com").unwrap();
741 assert!(entry.expires_at < SystemTime::now());
742 }
743
744 #[tokio::test]
745 async fn test_token_cache_multiple_users() {
746 let backend = OAuth2Backend::new(OAuth2Config::default());
747
748 {
749 let mut cache = backend.token_cache.write().await;
750 for i in 1..=5 {
751 cache.insert(
752 format!("user{}@example.com", i),
753 TokenCacheEntry {
754 username: format!("user{}@example.com", i),
755 expires_at: SystemTime::now() + Duration::from_secs(300),
756 },
757 );
758 }
759 }
760
761 let cache = backend.token_cache.read().await;
762 assert_eq!(cache.len(), 5);
763 }
764
765 #[tokio::test]
766 async fn test_list_users_with_cached_tokens() {
767 let backend = OAuth2Backend::new(OAuth2Config::default());
768
769 {
770 let mut cache = backend.token_cache.write().await;
771 cache.insert(
772 "user1@example.com".to_string(),
773 TokenCacheEntry {
774 username: "user1@example.com".to_string(),
775 expires_at: SystemTime::now() + Duration::from_secs(300),
776 },
777 );
778 cache.insert(
779 "user2@example.com".to_string(),
780 TokenCacheEntry {
781 username: "user2@example.com".to_string(),
782 expires_at: SystemTime::now() + Duration::from_secs(300),
783 },
784 );
785 }
786
787 let users = backend.list_users().await.unwrap();
788 assert_eq!(users.len(), 2);
789 }
790
791 #[test]
796 fn test_claims_structure() {
797 let claims = Claims {
798 sub: "user123".to_string(),
799 email: Some("user@example.com".to_string()),
800 exp: 1234567890,
801 iat: 1234567800,
802 iss: "https://accounts.google.com".to_string(),
803 aud: "client-id".to_string(),
804 };
805 assert_eq!(claims.sub, "user123");
806 assert_eq!(claims.email.unwrap(), "user@example.com");
807 }
808
809 #[test]
810 fn test_claims_without_email() {
811 let claims = Claims {
812 sub: "user123".to_string(),
813 email: None,
814 exp: 1234567890,
815 iat: 1234567800,
816 iss: "https://accounts.google.com".to_string(),
817 aud: "client-id".to_string(),
818 };
819 assert_eq!(claims.sub, "user123");
820 assert!(claims.email.is_none());
821 }
822
823 #[test]
828 fn test_token_cache_entry() {
829 let entry = TokenCacheEntry {
830 username: "user@example.com".to_string(),
831 expires_at: SystemTime::now() + Duration::from_secs(300),
832 };
833 assert_eq!(entry.username, "user@example.com");
834 assert!(entry.expires_at > SystemTime::now());
835 }
836
837 #[test]
838 fn test_token_cache_entry_expired() {
839 let entry = TokenCacheEntry {
840 username: "user@example.com".to_string(),
841 expires_at: SystemTime::now() - Duration::from_secs(10),
842 };
843 assert!(entry.expires_at < SystemTime::now());
844 }
845
846 #[test]
851 fn test_google_provider_config() {
852 let provider = OidcProvider::Google {
853 client_id: "google-client-id".to_string(),
854 client_secret: "google-secret".to_string(),
855 };
856
857 if let OidcProvider::Google { client_id, .. } = &provider {
858 assert_eq!(client_id, "google-client-id");
859 } else {
860 panic!("Expected Google provider");
861 }
862 }
863
864 #[test]
865 fn test_microsoft_provider_config() {
866 let provider = OidcProvider::Microsoft {
867 tenant_id: "tenant-123".to_string(),
868 client_id: "ms-client-id".to_string(),
869 client_secret: "ms-secret".to_string(),
870 };
871
872 if let OidcProvider::Microsoft { tenant_id, .. } = &provider {
873 assert_eq!(tenant_id, "tenant-123");
874 } else {
875 panic!("Expected Microsoft provider");
876 }
877 }
878
879 #[test]
880 fn test_generic_provider_config() {
881 let provider = OidcProvider::Generic {
882 issuer_url: "https://auth.example.com".to_string(),
883 client_id: "generic-client".to_string(),
884 client_secret: "generic-secret".to_string(),
885 jwks_url: "https://auth.example.com/.well-known/jwks.json".to_string(),
886 };
887
888 if let OidcProvider::Generic { issuer_url, .. } = &provider {
889 assert_eq!(issuer_url, "https://auth.example.com");
890 } else {
891 panic!("Expected Generic provider");
892 }
893 }
894
895 #[test]
900 fn test_multiple_allowed_algorithms() {
901 let config = OAuth2Config {
902 allowed_algorithms: vec![
903 Algorithm::RS256,
904 Algorithm::RS384,
905 Algorithm::RS512,
906 Algorithm::ES256,
907 ],
908 ..Default::default()
909 };
910 assert_eq!(config.allowed_algorithms.len(), 4);
911 assert!(config.allowed_algorithms.contains(&Algorithm::RS256));
912 assert!(config.allowed_algorithms.contains(&Algorithm::ES256));
913 }
914
915 #[test]
916 fn test_single_algorithm_rs256() {
917 let config = OAuth2Config {
918 allowed_algorithms: vec![Algorithm::RS256],
919 ..Default::default()
920 };
921 assert_eq!(config.allowed_algorithms.len(), 1);
922 assert_eq!(config.allowed_algorithms[0], Algorithm::RS256);
923 }
924
925 #[test]
930 fn test_jwks_structure() {
931 let jwks = Jwks { keys: vec![] };
932 assert_eq!(jwks.keys.len(), 0);
933 }
934
935 #[test]
936 fn test_jwk_structure() {
937 let jwk = Jwk {
938 kid: "key-1".to_string(),
939 kty: "RSA".to_string(),
940 key_use: Some("sig".to_string()),
941 alg: Some("RS256".to_string()),
942 n: Some("modulus".to_string()),
943 e: Some("AQAB".to_string()),
944 };
945 assert_eq!(jwk.kid, "key-1");
946 assert_eq!(jwk.kty, "RSA");
947 }
948
949 #[tokio::test]
954 async fn test_authenticate_empty_token() {
955 let backend = OAuth2Backend::new(OAuth2Config::default());
956 let username = Username::new("user@example.com".to_string()).unwrap();
957 let result = backend.authenticate(&username, "").await;
958 assert!(result.is_ok());
959 assert!(!result.unwrap());
960 }
961
962 #[tokio::test]
963 async fn test_authenticate_invalid_token() {
964 let backend = OAuth2Backend::new(OAuth2Config::default());
965 let username = Username::new("user@example.com".to_string()).unwrap();
966 let result = backend.authenticate(&username, "invalid-token").await;
967 assert!(result.is_ok());
968 assert!(!result.unwrap());
969 }
970
971 #[test]
972 fn test_config_with_all_options() {
973 let config = OAuth2Config {
974 provider: OidcProvider::Google {
975 client_id: "client".to_string(),
976 client_secret: "secret".to_string(),
977 },
978 introspection_endpoint: Some("https://oauth.example.com/introspect".to_string()),
979 jwks_cache_ttl: 1800,
980 enable_refresh_tokens: false,
981 allowed_algorithms: vec![Algorithm::RS256, Algorithm::RS384],
982 };
983
984 assert!(config.introspection_endpoint.is_some());
985 assert_eq!(config.jwks_cache_ttl, 1800);
986 assert!(!config.enable_refresh_tokens);
987 assert_eq!(config.allowed_algorithms.len(), 2);
988 }
989
990 #[tokio::test]
995 async fn test_verify_identity_invalid_username() {
996 let backend = OAuth2Backend::new(OAuth2Config::default());
997 let username = Username::new("nonexistent@example.com".to_string()).unwrap();
999 let result = backend.verify_identity(&username).await;
1000 assert!(result.is_ok());
1001 assert!(!result.unwrap());
1002 }
1003
1004 #[tokio::test]
1009 async fn test_concurrent_cache_access() {
1010 let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1011
1012 let mut handles = vec![];
1013 for i in 0..10 {
1014 let backend = Arc::clone(&backend);
1015 let handle = tokio::spawn(async move {
1016 let mut cache = backend.token_cache.write().await;
1017 cache.insert(
1018 format!("user{}@example.com", i),
1019 TokenCacheEntry {
1020 username: format!("user{}@example.com", i),
1021 expires_at: SystemTime::now() + Duration::from_secs(300),
1022 },
1023 );
1024 });
1025 handles.push(handle);
1026 }
1027
1028 for handle in handles {
1029 handle.await.unwrap();
1030 }
1031
1032 let cache = backend.token_cache.read().await;
1033 assert_eq!(cache.len(), 10);
1034 }
1035
1036 #[tokio::test]
1041 async fn test_introspect_without_endpoint() {
1042 let backend = OAuth2Backend::new(OAuth2Config::default());
1043 let result = backend.introspect_token("test-token").await;
1044 assert!(result.is_err());
1045 assert!(result.unwrap_err().to_string().contains("not configured"));
1046 }
1047
1048 #[tokio::test]
1049 async fn test_refresh_token_disabled() {
1050 let config = OAuth2Config {
1051 enable_refresh_tokens: false,
1052 ..Default::default()
1053 };
1054 let backend = OAuth2Backend::new(config);
1055 let result = backend.refresh_token("refresh-token").await;
1056 assert!(result.is_err());
1057 assert!(result.unwrap_err().to_string().contains("not enabled"));
1058 }
1059
1060 #[test]
1065 fn test_parse_xoauth2_response_valid() {
1066 let response =
1067 OAuth2Backend::encode_xoauth2_response("user@example.com", "ya29.a0AfH6SMBx...");
1068 let result = OAuth2Backend::parse_xoauth2_response(&response);
1069 assert!(result.is_ok());
1070 let (username, token) = result.unwrap();
1071 assert_eq!(username, "user@example.com");
1072 assert_eq!(token, "ya29.a0AfH6SMBx...");
1073 }
1074
1075 #[test]
1076 fn test_encode_xoauth2_response() {
1077 let encoded = OAuth2Backend::encode_xoauth2_response("test@example.com", "token123");
1078 assert!(!encoded.is_empty());
1079
1080 let (username, token) = OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1082 assert_eq!(username, "test@example.com");
1083 assert_eq!(token, "token123");
1084 }
1085
1086 #[test]
1087 fn test_parse_xoauth2_response_invalid_base64() {
1088 let result = OAuth2Backend::parse_xoauth2_response("not-valid-base64!");
1089 assert!(result.is_err());
1090 assert!(result.unwrap_err().to_string().contains("decode"));
1091 }
1092
1093 #[test]
1094 fn test_parse_xoauth2_response_missing_username() {
1095 let invalid = BASE64.encode(b"auth=Bearer token123\x01\x01");
1097 let result = OAuth2Backend::parse_xoauth2_response(&invalid);
1098 assert!(result.is_err());
1099 assert!(result.unwrap_err().to_string().contains("username"));
1100 }
1101
1102 #[test]
1103 fn test_parse_xoauth2_response_missing_token() {
1104 let invalid = BASE64.encode(b"user=test@example.com\x01\x01");
1106 let result = OAuth2Backend::parse_xoauth2_response(&invalid);
1107 assert!(result.is_err());
1108 assert!(result.unwrap_err().to_string().contains("token"));
1109 }
1110
1111 #[test]
1112 fn test_xoauth2_round_trip() {
1113 let original_username = "roundtrip@example.com";
1114 let original_token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...";
1115
1116 let encoded = OAuth2Backend::encode_xoauth2_response(original_username, original_token);
1117 let (decoded_username, decoded_token) =
1118 OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1119
1120 assert_eq!(decoded_username, original_username);
1121 assert_eq!(decoded_token, original_token);
1122 }
1123
1124 #[test]
1125 fn test_xoauth2_special_characters() {
1126 let username = "user+tag@example.com";
1127 let token = "token-with-special_chars.123";
1128
1129 let encoded = OAuth2Backend::encode_xoauth2_response(username, token);
1130 let (decoded_username, decoded_token) =
1131 OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1132
1133 assert_eq!(decoded_username, username);
1134 assert_eq!(decoded_token, token);
1135 }
1136
1137 #[tokio::test]
1142 async fn test_cleanup_expired_tokens() {
1143 let backend = OAuth2Backend::new(OAuth2Config::default());
1144
1145 {
1146 let mut cache = backend.token_cache.write().await;
1147 cache.insert(
1149 "expired@example.com".to_string(),
1150 TokenCacheEntry {
1151 username: "expired@example.com".to_string(),
1152 expires_at: SystemTime::now() - Duration::from_secs(10),
1153 },
1154 );
1155 cache.insert(
1157 "valid@example.com".to_string(),
1158 TokenCacheEntry {
1159 username: "valid@example.com".to_string(),
1160 expires_at: SystemTime::now() + Duration::from_secs(300),
1161 },
1162 );
1163 }
1164
1165 backend.cleanup_expired_tokens().await;
1166
1167 let cache = backend.token_cache.read().await;
1168 assert_eq!(cache.len(), 1);
1169 assert!(cache.contains_key("valid@example.com"));
1170 assert!(!cache.contains_key("expired@example.com"));
1171 }
1172
1173 #[tokio::test]
1174 async fn test_token_cache_size() {
1175 let backend = OAuth2Backend::new(OAuth2Config::default());
1176
1177 {
1178 let mut cache = backend.token_cache.write().await;
1179 for i in 1..=3 {
1180 cache.insert(
1181 format!("user{}@example.com", i),
1182 TokenCacheEntry {
1183 username: format!("user{}@example.com", i),
1184 expires_at: SystemTime::now() + Duration::from_secs(300),
1185 },
1186 );
1187 }
1188 }
1189
1190 let size = backend.token_cache_size().await;
1191 assert_eq!(size, 3);
1192 }
1193
1194 #[tokio::test]
1195 async fn test_invalidate_token() {
1196 let backend = OAuth2Backend::new(OAuth2Config::default());
1197
1198 {
1199 let mut cache = backend.token_cache.write().await;
1200 cache.insert(
1201 "user@example.com".to_string(),
1202 TokenCacheEntry {
1203 username: "user@example.com".to_string(),
1204 expires_at: SystemTime::now() + Duration::from_secs(300),
1205 },
1206 );
1207 }
1208
1209 assert_eq!(backend.token_cache_size().await, 1);
1210
1211 backend.invalidate_token("user@example.com").await;
1212
1213 assert_eq!(backend.token_cache_size().await, 0);
1214 }
1215
1216 #[tokio::test]
1217 async fn test_clear_jwks_cache() {
1218 let backend = OAuth2Backend::new(OAuth2Config::default());
1219
1220 {
1221 let mut cache = backend.jwks_cache.write().await;
1222 *cache = Some((Jwks { keys: vec![] }, SystemTime::now()));
1223 }
1224
1225 backend.clear_jwks_cache().await;
1226
1227 let cache = backend.jwks_cache.read().await;
1228 assert!(cache.is_none());
1229 }
1230
/// Builds a backend configured with the Google provider and checks that the
/// stored provider is the `Google` variant.
///
/// Despite the name, no URL is inspected here; only the variant is checked.
#[test]
fn test_google_jwks_url() {
    let provider = OidcProvider::Google {
        client_id: "client".to_string(),
        client_secret: "secret".to_string(),
    };
    let backend = OAuth2Backend::new(OAuth2Config {
        provider,
        ..Default::default()
    });

    assert!(matches!(backend.config.provider, OidcProvider::Google { .. }));
}
1252
/// Checks that a Microsoft provider keeps its tenant id and that the JWKS URL
/// template, filled with that tenant id, yields the exact expected URL.
///
/// NOTE(review): the URL template is written out in this test rather than
/// obtained from the backend — confirm it matches what the backend builds.
#[test]
fn test_microsoft_urls() {
    let tenant_id = "tenant-abc-123";
    let provider = OidcProvider::Microsoft {
        tenant_id: tenant_id.to_string(),
        client_id: "client".to_string(),
        client_secret: "secret".to_string(),
    };

    // Use `match` with a panicking fallback instead of `if let`, so the test
    // fails loudly rather than passing without running any assertion when the
    // variant is not the expected one.
    let tid = match &provider {
        OidcProvider::Microsoft { tenant_id: tid, .. } => tid,
        other => panic!("expected OidcProvider::Microsoft, got {:?}", other),
    };
    assert_eq!(tid.as_str(), tenant_id);

    // Exact equality pins the whole URL, not just the presence of the tenant id.
    let jwks_url = format!(
        "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
        tid
    );
    assert_eq!(
        jwks_url,
        "https://login.microsoftonline.com/tenant-abc-123/discovery/v2.0/keys"
    );
}
1270
/// Checks that a `Generic` provider stores the issuer URL and JWKS URL it
/// was constructed with.
#[test]
fn test_generic_provider_urls() {
    let issuer = "https://auth.company.com";
    let jwks_url = "https://auth.company.com/.well-known/jwks.json";

    let provider = OidcProvider::Generic {
        issuer_url: issuer.to_string(),
        client_id: "client".to_string(),
        client_secret: "secret".to_string(),
        jwks_url: jwks_url.to_string(),
    };

    // `match` with a panicking fallback instead of `if let`: a variant
    // mismatch must fail the test, not silently skip the assertions.
    match &provider {
        OidcProvider::Generic {
            issuer_url,
            jwks_url: jwks,
            ..
        } => {
            assert_eq!(issuer_url, issuer);
            assert_eq!(jwks, jwks_url);
        }
        other => panic!("expected OidcProvider::Generic, got {:?}", other),
    }
}
1293
1294 #[tokio::test]
1299 async fn test_multiple_cleanup_calls() {
1300 let backend = OAuth2Backend::new(OAuth2Config::default());
1301
1302 {
1303 let mut cache = backend.token_cache.write().await;
1304 cache.insert(
1305 "expired@example.com".to_string(),
1306 TokenCacheEntry {
1307 username: "expired@example.com".to_string(),
1308 expires_at: SystemTime::now() - Duration::from_secs(10),
1309 },
1310 );
1311 }
1312
1313 backend.cleanup_expired_tokens().await;
1315 backend.cleanup_expired_tokens().await;
1316 backend.cleanup_expired_tokens().await;
1317
1318 let cache = backend.token_cache.read().await;
1319 assert_eq!(cache.len(), 0);
1320 }
1321
1322 #[tokio::test]
1323 async fn test_invalidate_nonexistent_token() {
1324 let backend = OAuth2Backend::new(OAuth2Config::default());
1325 backend.invalidate_token("nonexistent@example.com").await;
1327 assert_eq!(backend.token_cache_size().await, 0);
1328 }
1329
/// Round-trips an XOAUTH2 response with an empty username through
/// `encode_xoauth2_response` / `parse_xoauth2_response` and checks that
/// parsing succeeds and yields an empty username.
#[test]
fn test_xoauth2_empty_username() {
    let payload = OAuth2Backend::encode_xoauth2_response("", "token");
    let parsed = OAuth2Backend::parse_xoauth2_response(&payload);
    assert!(parsed.is_ok());

    // First tuple element is the username.
    let user = parsed.unwrap().0;
    assert_eq!(user, "");
}
1338
/// Round-trips an XOAUTH2 response with an empty token through
/// `encode_xoauth2_response` / `parse_xoauth2_response` and checks that
/// parsing succeeds and yields an empty token.
#[test]
fn test_xoauth2_empty_token() {
    let payload = OAuth2Backend::encode_xoauth2_response("user@example.com", "");
    let parsed = OAuth2Backend::parse_xoauth2_response(&payload);
    assert!(parsed.is_ok());

    // Second tuple element is the token.
    let bearer = parsed.unwrap().1;
    assert_eq!(bearer, "");
}
1347
/// Round-trips an XOAUTH2 response carrying a 1000-character token and
/// checks that the parsed token has the same length.
#[test]
fn test_xoauth2_long_token() {
    let oversized = "a".repeat(1000);
    let payload = OAuth2Backend::encode_xoauth2_response("user@example.com", &oversized);

    let parsed = OAuth2Backend::parse_xoauth2_response(&payload);
    assert!(parsed.is_ok());

    // Second tuple element is the token.
    let bearer = parsed.unwrap().1;
    assert_eq!(bearer.len(), 1000);
}
1357
/// Builds a backend from a small `Generic` configuration (no introspection
/// endpoint, refresh tokens disabled, RS256 only, 60-second JWKS TTL) and
/// checks that the stored JWKS cache TTL is at least 60.
#[test]
fn test_config_validation_minimal() {
    let provider = OidcProvider::Generic {
        issuer_url: "https://minimal.example.com".to_string(),
        client_id: "c".to_string(),
        client_secret: "s".to_string(),
        jwks_url: "https://minimal.example.com/jwks".to_string(),
    };

    let backend = OAuth2Backend::new(OAuth2Config {
        provider,
        introspection_endpoint: None,
        jwks_cache_ttl: 60,
        enable_refresh_tokens: false,
        allowed_algorithms: vec![Algorithm::RS256],
    });

    assert!(backend.config.jwks_cache_ttl >= 60);
}
1380
/// Builds a backend from a fully populated Google configuration (introspection
/// endpoint set, refresh tokens enabled, five allowed algorithms) and checks
/// that the algorithm list and refresh flag are stored as given.
#[test]
fn test_config_validation_maximal() {
    let provider = OidcProvider::Google {
        client_id: "very-long-client-id-with-many-characters".to_string(),
        client_secret: "very-long-secret-with-special-chars!@#$%".to_string(),
    };
    let introspection_endpoint =
        Some("https://oauth.googleapis.com/token/introspect".to_string());
    let allowed_algorithms = vec![
        Algorithm::RS256,
        Algorithm::RS384,
        Algorithm::RS512,
        Algorithm::ES256,
        Algorithm::ES384,
    ];

    let backend = OAuth2Backend::new(OAuth2Config {
        provider,
        introspection_endpoint,
        jwks_cache_ttl: 86400,
        enable_refresh_tokens: true,
        allowed_algorithms,
    });

    assert_eq!(backend.config.allowed_algorithms.len(), 5);
    assert!(backend.config.enable_refresh_tokens);
}
1406
1407 #[tokio::test]
1412 async fn test_concurrent_jwks_cache_access() {
1413 let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1414
1415 let mut handles = vec![];
1416 for _ in 0..5 {
1417 let backend = Arc::clone(&backend);
1418 let handle = tokio::spawn(async move {
1419 backend.clear_jwks_cache().await;
1420 });
1421 handles.push(handle);
1422 }
1423
1424 for handle in handles {
1425 handle.await.unwrap();
1426 }
1427
1428 let cache = backend.jwks_cache.read().await;
1429 assert!(cache.is_none());
1430 }
1431
1432 #[tokio::test]
1433 async fn test_concurrent_cleanup() {
1434 let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1435
1436 {
1437 let mut cache = backend.token_cache.write().await;
1438 for i in 0..100 {
1439 cache.insert(
1440 format!("user{}@example.com", i),
1441 TokenCacheEntry {
1442 username: format!("user{}@example.com", i),
1443 expires_at: if i % 2 == 0 {
1444 SystemTime::now() + Duration::from_secs(300)
1445 } else {
1446 SystemTime::now() - Duration::from_secs(10)
1447 },
1448 },
1449 );
1450 }
1451 }
1452
1453 let mut handles = vec![];
1454 for _ in 0..10 {
1455 let backend = Arc::clone(&backend);
1456 let handle = tokio::spawn(async move {
1457 backend.cleanup_expired_tokens().await;
1458 });
1459 handles.push(handle);
1460 }
1461
1462 for handle in handles {
1463 handle.await.unwrap();
1464 }
1465
1466 let cache = backend.token_cache.read().await;
1467 assert_eq!(cache.len(), 50);
1468 }
1469}