Skip to main content

rusmes_auth/backends/
oauth2.rs

1//! OAuth2/OIDC authentication backend
2
3use crate::AuthBackend;
4use async_trait::async_trait;
5use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
6use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
7use rusmes_proto::Username;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use tokio::sync::RwLock;
13
/// Identity provider that issues the tokens this backend accepts.
///
/// Every variant carries the OAuth2 client credentials registered with the
/// provider. The `client_id` is also the audience expected in validated JWTs,
/// and both credentials are sent with introspection and token requests.
#[derive(Clone, Debug)]
pub enum OidcProvider {
    /// Google OAuth2; signing keys come from Google's fixed JWKS endpoint.
    Google {
        /// OAuth2 client identifier.
        client_id: String,
        /// OAuth2 client secret.
        client_secret: String,
    },
    /// Microsoft Azure AD; key and token endpoints are built per tenant.
    Microsoft {
        /// Tenant identifier interpolated into the Microsoft endpoint URLs.
        tenant_id: String,
        /// OAuth2 client identifier.
        client_id: String,
        /// OAuth2 client secret.
        client_secret: String,
    },
    /// Any other OIDC provider, described by explicit URLs.
    Generic {
        /// Issuer base URL; the token endpoint is derived as `<issuer_url>/token`.
        issuer_url: String,
        /// OAuth2 client identifier.
        client_id: String,
        /// OAuth2 client secret.
        client_secret: String,
        /// URL the JSON Web Key Set is downloaded from.
        jwks_url: String,
    },
}
36
37/// JWT claims structure
38#[derive(Debug, Serialize, Deserialize)]
39struct Claims {
40    sub: String,
41    email: Option<String>,
42    exp: u64,
43    iat: u64,
44    iss: String,
45    aud: String,
46}
47
48/// Token introspection response
49#[derive(Debug, Deserialize)]
50#[allow(dead_code)]
51struct IntrospectionResponse {
52    active: bool,
53    #[serde(default)]
54    username: Option<String>,
55    #[serde(default)]
56    email: Option<String>,
57    #[serde(default)]
58    exp: Option<u64>,
59}
60
61/// JWKS (JSON Web Key Set) structure
62#[derive(Debug, Clone, Deserialize)]
63struct Jwks {
64    keys: Vec<Jwk>,
65}
66
67/// JSON Web Key
68#[derive(Debug, Clone, Deserialize)]
69#[allow(dead_code)]
70struct Jwk {
71    kid: String,
72    kty: String,
73    #[serde(rename = "use")]
74    key_use: Option<String>,
75    alg: Option<String>,
76    n: Option<String>,
77    e: Option<String>,
78}
79
/// Record of a successful authentication, stored in the token cache under
/// the user's name.
// `username` repeats the map key, which presumably is why `dead_code` is
// allowed here.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct TokenCacheEntry {
    /// Username the validated token resolved to.
    username: String,
    /// Point in time after which the entry counts as expired.
    expires_at: SystemTime,
}
87
88/// OAuth2/OIDC configuration
89#[derive(Debug, Clone)]
90pub struct OAuth2Config {
91    /// OIDC provider
92    pub provider: OidcProvider,
93    /// Token introspection endpoint
94    pub introspection_endpoint: Option<String>,
95    /// JWKS cache TTL in seconds
96    pub jwks_cache_ttl: u64,
97    /// Enable refresh token support
98    pub enable_refresh_tokens: bool,
99    /// Allowed algorithms for JWT validation
100    pub allowed_algorithms: Vec<Algorithm>,
101}
102
103impl Default for OAuth2Config {
104    fn default() -> Self {
105        Self {
106            provider: OidcProvider::Generic {
107                issuer_url: "https://example.com".to_string(),
108                client_id: "client-id".to_string(),
109                client_secret: "client-secret".to_string(),
110                jwks_url: "https://example.com/.well-known/jwks.json".to_string(),
111            },
112            introspection_endpoint: None,
113            jwks_cache_ttl: 3600,
114            enable_refresh_tokens: true,
115            allowed_algorithms: vec![Algorithm::RS256],
116        }
117    }
118}
119
/// OAuth2/OIDC authentication backend
///
/// Accepts bearer tokens, validating them as JWTs against the provider's
/// JWKS and falling back to token introspection.
pub struct OAuth2Backend {
    /// Provider settings and validation options.
    config: OAuth2Config,
    /// Successful authentications, keyed by username.
    token_cache: Arc<RwLock<HashMap<String, TokenCacheEntry>>>,
    /// Most recently fetched key set together with the time it was stored.
    jwks_cache: Arc<RwLock<Option<(Jwks, SystemTime)>>>,
    /// HTTP client used for JWKS, introspection and token requests.
    client: reqwest::Client,
}
127
128impl OAuth2Backend {
129    /// Create a new OAuth2 authentication backend
130    pub fn new(config: OAuth2Config) -> Self {
131        Self {
132            config,
133            token_cache: Arc::new(RwLock::new(HashMap::new())),
134            jwks_cache: Arc::new(RwLock::new(None)),
135            client: reqwest::Client::new(),
136        }
137    }
138
139    /// Parse XOAUTH2 SASL initial response
140    ///
141    /// Format: `base64(user=<username>\x01auth=Bearer <token>\x01\x01)`
142    pub fn parse_xoauth2_response(response: &str) -> anyhow::Result<(String, String)> {
143        // Decode base64
144        let decoded = BASE64
145            .decode(response.as_bytes())
146            .map_err(|e| anyhow::anyhow!("Failed to decode XOAUTH2 response: {}", e))?;
147
148        let decoded_str = String::from_utf8(decoded)
149            .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in XOAUTH2 response: {}", e))?;
150
151        // Split by \x01
152        let parts: Vec<&str> = decoded_str.split('\x01').collect();
153
154        // Extract username and token
155        let mut username = None;
156        let mut token = None;
157
158        for part in &parts {
159            if part.starts_with("user=") {
160                username = part.strip_prefix("user=").map(|s| s.to_string());
161            } else if part.starts_with("auth=Bearer ") {
162                token = part.strip_prefix("auth=Bearer ").map(|s| s.to_string());
163            }
164        }
165
166        let username = username.ok_or_else(|| anyhow::anyhow!("Missing username in XOAUTH2"))?;
167        let token = token.ok_or_else(|| anyhow::anyhow!("Missing token in XOAUTH2"))?;
168
169        Ok((username, token))
170    }
171
172    /// Encode XOAUTH2 SASL initial response
173    ///
174    /// Format: `base64(user=<username>\x01auth=Bearer <token>\x01\x01)`
175    #[allow(dead_code)]
176    pub fn encode_xoauth2_response(username: &str, token: &str) -> String {
177        let response = format!("user={}\x01auth=Bearer {}\x01\x01", username, token);
178        BASE64.encode(response.as_bytes())
179    }
180
181    /// Clear expired entries from token cache
182    pub async fn cleanup_expired_tokens(&self) {
183        let mut cache = self.token_cache.write().await;
184        let now = SystemTime::now();
185        cache.retain(|_, entry| entry.expires_at > now);
186    }
187
188    /// Get token cache size
189    #[allow(dead_code)]
190    pub async fn token_cache_size(&self) -> usize {
191        let cache = self.token_cache.read().await;
192        cache.len()
193    }
194
195    /// Invalidate cached token for a user
196    #[allow(dead_code)]
197    pub async fn invalidate_token(&self, username: &str) {
198        let mut cache = self.token_cache.write().await;
199        cache.remove(username);
200    }
201
202    /// Clear JWKS cache (force refresh on next validation)
203    #[allow(dead_code)]
204    pub async fn clear_jwks_cache(&self) {
205        let mut cache = self.jwks_cache.write().await;
206        *cache = None;
207    }
208
209    /// Get JWKS from provider
210    async fn get_jwks(&self) -> anyhow::Result<Jwks> {
211        // Check cache first
212        {
213            let cache = self.jwks_cache.read().await;
214            if let Some((jwks, cached_at)) = &*cache {
215                if cached_at.elapsed().unwrap_or(Duration::MAX).as_secs()
216                    < self.config.jwks_cache_ttl
217                {
218                    return Ok(jwks.clone());
219                }
220            }
221        }
222
223        // Fetch from provider
224        let jwks_url = match &self.config.provider {
225            OidcProvider::Google { .. } => "https://www.googleapis.com/oauth2/v3/certs",
226            OidcProvider::Microsoft { tenant_id, .. } => &format!(
227                "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
228                tenant_id
229            ),
230            OidcProvider::Generic { jwks_url, .. } => jwks_url.as_str(),
231        };
232
233        let jwks: Jwks = self.client.get(jwks_url).send().await?.json().await?;
234
235        // Update cache
236        {
237            let mut cache = self.jwks_cache.write().await;
238            *cache = Some((jwks.clone(), SystemTime::now()));
239        }
240
241        Ok(jwks)
242    }
243
244    /// Validate JWT token
245    async fn validate_jwt(&self, token: &str) -> anyhow::Result<Claims> {
246        // Decode header to get kid
247        let header = decode_header(token)?;
248        let kid = header
249            .kid
250            .ok_or_else(|| anyhow::anyhow!("No kid in JWT header"))?;
251
252        // Get JWKS
253        let jwks = self.get_jwks().await?;
254
255        // Find matching key
256        let jwk = jwks
257            .keys
258            .iter()
259            .find(|k| k.kid == kid)
260            .ok_or_else(|| anyhow::anyhow!("No matching key found in JWKS"))?;
261
262        // Construct RSA public key from JWK
263        let n = jwk
264            .n
265            .as_ref()
266            .ok_or_else(|| anyhow::anyhow!("Missing n in JWK"))?;
267        let e = jwk
268            .e
269            .as_ref()
270            .ok_or_else(|| anyhow::anyhow!("Missing e in JWK"))?;
271
272        let n_bytes = BASE64.decode(n)?;
273        let e_bytes = BASE64.decode(e)?;
274
275        // Create decoding key
276        let decoding_key =
277            DecodingKey::from_rsa_components(&BASE64.encode(&n_bytes), &BASE64.encode(&e_bytes))?;
278
279        // Validate token
280        let mut validation = Validation::new(Algorithm::RS256);
281        validation.algorithms = self.config.allowed_algorithms.clone();
282
283        // Set expected audience based on provider
284        let expected_aud = match &self.config.provider {
285            OidcProvider::Google { client_id, .. } => client_id.clone(),
286            OidcProvider::Microsoft { client_id, .. } => client_id.clone(),
287            OidcProvider::Generic { client_id, .. } => client_id.clone(),
288        };
289        validation.set_audience(&[&expected_aud]);
290
291        let token_data = decode::<Claims>(token, &decoding_key, &validation)?;
292
293        Ok(token_data.claims)
294    }
295
296    /// Introspect token at provider's introspection endpoint
297    async fn introspect_token(&self, token: &str) -> anyhow::Result<IntrospectionResponse> {
298        let endpoint = self
299            .config
300            .introspection_endpoint
301            .as_ref()
302            .ok_or_else(|| anyhow::anyhow!("Token introspection endpoint not configured"))?;
303
304        let (client_id, client_secret) = match &self.config.provider {
305            OidcProvider::Google {
306                client_id,
307                client_secret,
308            } => (client_id, client_secret),
309            OidcProvider::Microsoft {
310                client_id,
311                client_secret,
312                ..
313            } => (client_id, client_secret),
314            OidcProvider::Generic {
315                client_id,
316                client_secret,
317                ..
318            } => (client_id, client_secret),
319        };
320
321        let mut params = HashMap::new();
322        params.insert("token", token);
323        params.insert("client_id", client_id);
324        params.insert("client_secret", client_secret);
325
326        let response = self
327            .client
328            .post(endpoint)
329            .form(&params)
330            .send()
331            .await?
332            .json::<IntrospectionResponse>()
333            .await?;
334
335        Ok(response)
336    }
337
338    /// Authenticate using XOAUTH2 SASL mechanism
339    async fn xoauth2_authenticate(&self, token: &str) -> anyhow::Result<String> {
340        // Try JWT validation first
341        if let Ok(claims) = self.validate_jwt(token).await {
342            return Ok(claims.email.or(Some(claims.sub)).unwrap_or_default());
343        }
344
345        // Fall back to token introspection
346        let introspection = self.introspect_token(token).await?;
347
348        if !introspection.active {
349            return Err(anyhow::anyhow!("Token is not active"));
350        }
351
352        introspection
353            .email
354            .or(introspection.username)
355            .ok_or_else(|| anyhow::anyhow!("No username in token"))
356    }
357
358    /// Refresh access token using refresh token
359    #[allow(dead_code)]
360    async fn refresh_token(&self, refresh_token: &str) -> anyhow::Result<String> {
361        if !self.config.enable_refresh_tokens {
362            return Err(anyhow::anyhow!("Refresh tokens not enabled"));
363        }
364
365        let token_endpoint = match &self.config.provider {
366            OidcProvider::Google { .. } => "https://oauth2.googleapis.com/token",
367            OidcProvider::Microsoft { tenant_id, .. } => &format!(
368                "https://login.microsoftonline.com/{}/oauth2/v2.0/token",
369                tenant_id
370            ),
371            OidcProvider::Generic { issuer_url, .. } => &format!("{}/token", issuer_url),
372        };
373
374        let (client_id, client_secret) = match &self.config.provider {
375            OidcProvider::Google {
376                client_id,
377                client_secret,
378            } => (client_id, client_secret),
379            OidcProvider::Microsoft {
380                client_id,
381                client_secret,
382                ..
383            } => (client_id, client_secret),
384            OidcProvider::Generic {
385                client_id,
386                client_secret,
387                ..
388            } => (client_id, client_secret),
389        };
390
391        let mut params = HashMap::new();
392        params.insert("grant_type", "refresh_token");
393        params.insert("refresh_token", refresh_token);
394        params.insert("client_id", client_id);
395        params.insert("client_secret", client_secret);
396
397        #[derive(Deserialize)]
398        struct TokenResponse {
399            access_token: String,
400        }
401
402        let response = self
403            .client
404            .post(token_endpoint)
405            .form(&params)
406            .send()
407            .await?
408            .json::<TokenResponse>()
409            .await?;
410
411        Ok(response.access_token)
412    }
413}
414
415#[async_trait]
416impl AuthBackend for OAuth2Backend {
417    async fn authenticate(&self, username: &Username, password: &str) -> anyhow::Result<bool> {
418        // In OAuth2 flow, "password" is the access token
419        let token = password;
420
421        // Check cache first
422        {
423            let cache = self.token_cache.read().await;
424            if let Some(entry) = cache.get(&username.to_string()) {
425                if SystemTime::now() < entry.expires_at {
426                    return Ok(true);
427                }
428            }
429        }
430
431        // Validate token and get username
432        match self.xoauth2_authenticate(token).await {
433            Ok(token_username) => {
434                if token_username == username.to_string() {
435                    // Cache successful authentication
436                    let mut cache = self.token_cache.write().await;
437                    cache.insert(
438                        username.to_string(),
439                        TokenCacheEntry {
440                            username: token_username,
441                            expires_at: SystemTime::now() + Duration::from_secs(300),
442                        },
443                    );
444                    Ok(true)
445                } else {
446                    Ok(false)
447                }
448            }
449            Err(_) => Ok(false),
450        }
451    }
452
453    async fn verify_identity(&self, username: &Username) -> anyhow::Result<bool> {
454        // Check cache
455        let cache = self.token_cache.read().await;
456        Ok(cache.contains_key(&username.to_string()))
457    }
458
459    async fn list_users(&self) -> anyhow::Result<Vec<Username>> {
460        // OAuth2 backends don't maintain a user list
461        let cache = self.token_cache.read().await;
462        Ok(cache
463            .keys()
464            .filter_map(|k| Username::new(k.clone()).ok())
465            .collect())
466    }
467
468    async fn create_user(&self, _username: &Username, _password: &str) -> anyhow::Result<()> {
469        Err(anyhow::anyhow!(
470            "OAuth2 backend does not support user creation (external provider)"
471        ))
472    }
473
474    async fn delete_user(&self, _username: &Username) -> anyhow::Result<()> {
475        Err(anyhow::anyhow!(
476            "OAuth2 backend does not support user deletion (external provider)"
477        ))
478    }
479
480    async fn change_password(
481        &self,
482        _username: &Username,
483        _new_password: &str,
484    ) -> anyhow::Result<()> {
485        Err(anyhow::anyhow!(
486            "OAuth2 backend does not support password changes (external provider)"
487        ))
488    }
489}
490
491#[cfg(test)]
492mod tests {
493    use super::*;
494
495    // ========================================================================
496    // Configuration Tests
497    // ========================================================================
498
499    #[test]
500    fn test_oauth2_config_default() {
501        let config = OAuth2Config::default();
502        assert_eq!(config.jwks_cache_ttl, 3600);
503        assert!(config.enable_refresh_tokens);
504        assert_eq!(config.allowed_algorithms.len(), 1);
505    }
506
507    #[test]
508    fn test_oauth2_config_google() {
509        let config = OAuth2Config {
510            provider: OidcProvider::Google {
511                client_id: "test-client-id".to_string(),
512                client_secret: "test-secret".to_string(),
513            },
514            ..Default::default()
515        };
516        assert!(matches!(config.provider, OidcProvider::Google { .. }));
517    }
518
519    #[test]
520    fn test_oauth2_config_microsoft() {
521        let config = OAuth2Config {
522            provider: OidcProvider::Microsoft {
523                tenant_id: "test-tenant".to_string(),
524                client_id: "test-client".to_string(),
525                client_secret: "test-secret".to_string(),
526            },
527            ..Default::default()
528        };
529        assert!(matches!(config.provider, OidcProvider::Microsoft { .. }));
530    }
531
532    #[test]
533    fn test_oauth2_config_generic() {
534        let config = OAuth2Config {
535            provider: OidcProvider::Generic {
536                issuer_url: "https://oidc.example.com".to_string(),
537                client_id: "client".to_string(),
538                client_secret: "secret".to_string(),
539                jwks_url: "https://oidc.example.com/jwks".to_string(),
540            },
541            ..Default::default()
542        };
543        assert!(matches!(config.provider, OidcProvider::Generic { .. }));
544    }
545
546    #[test]
547    fn test_allowed_algorithms() {
548        let config = OAuth2Config {
549            allowed_algorithms: vec![Algorithm::RS256, Algorithm::RS384, Algorithm::RS512],
550            ..Default::default()
551        };
552        assert_eq!(config.allowed_algorithms.len(), 3);
553    }
554
555    #[test]
556    fn test_introspection_endpoint_optional() {
557        let config = OAuth2Config::default();
558        assert!(config.introspection_endpoint.is_none());
559
560        let config_with_introspection = OAuth2Config {
561            introspection_endpoint: Some("https://example.com/introspect".to_string()),
562            ..Default::default()
563        };
564        assert!(config_with_introspection.introspection_endpoint.is_some());
565    }
566
567    #[test]
568    fn test_refresh_tokens_enabled() {
569        let config = OAuth2Config {
570            enable_refresh_tokens: true,
571            ..Default::default()
572        };
573        assert!(config.enable_refresh_tokens);
574
575        let config_disabled = OAuth2Config {
576            enable_refresh_tokens: false,
577            ..Default::default()
578        };
579        assert!(!config_disabled.enable_refresh_tokens);
580    }
581
582    #[test]
583    fn test_jwks_cache_ttl() {
584        let config = OAuth2Config {
585            jwks_cache_ttl: 7200,
586            ..Default::default()
587        };
588        assert_eq!(config.jwks_cache_ttl, 7200);
589    }
590
591    #[test]
592    fn test_config_clone() {
593        let config = OAuth2Config::default();
594        let cloned = config.clone();
595        assert_eq!(config.jwks_cache_ttl, cloned.jwks_cache_ttl);
596    }
597
598    // ========================================================================
599    // Backend Creation Tests
600    // ========================================================================
601
602    #[tokio::test]
603    async fn test_oauth2_backend_creation() {
604        let config = OAuth2Config::default();
605        let backend = OAuth2Backend::new(config);
606        let cache = backend.token_cache.read().await;
607        assert_eq!(cache.len(), 0);
608    }
609
610    #[tokio::test]
611    async fn test_token_cache_empty_on_creation() {
612        let backend = OAuth2Backend::new(OAuth2Config::default());
613        let cache = backend.token_cache.read().await;
614        assert!(cache.is_empty());
615    }
616
617    #[tokio::test]
618    async fn test_jwks_cache_empty_on_creation() {
619        let backend = OAuth2Backend::new(OAuth2Config::default());
620        let cache = backend.jwks_cache.read().await;
621        assert!(cache.is_none());
622    }
623
624    // ========================================================================
625    // AuthBackend Trait Tests
626    // ========================================================================
627
628    #[tokio::test]
629    async fn test_create_user_not_supported() {
630        let backend = OAuth2Backend::new(OAuth2Config::default());
631        let username = Username::new("user@example.com".to_string()).unwrap();
632        let result = backend.create_user(&username, "token").await;
633        assert!(result.is_err());
634        assert!(result
635            .unwrap_err()
636            .to_string()
637            .contains("external provider"));
638    }
639
640    #[tokio::test]
641    async fn test_delete_user_not_supported() {
642        let backend = OAuth2Backend::new(OAuth2Config::default());
643        let username = Username::new("user@example.com".to_string()).unwrap();
644        let result = backend.delete_user(&username).await;
645        assert!(result.is_err());
646        assert!(result
647            .unwrap_err()
648            .to_string()
649            .contains("external provider"));
650    }
651
652    #[tokio::test]
653    async fn test_change_password_not_supported() {
654        let backend = OAuth2Backend::new(OAuth2Config::default());
655        let username = Username::new("user@example.com".to_string()).unwrap();
656        let result = backend.change_password(&username, "newtoken").await;
657        assert!(result.is_err());
658        assert!(result
659            .unwrap_err()
660            .to_string()
661            .contains("external provider"));
662    }
663
664    #[tokio::test]
665    async fn test_list_users_empty() {
666        let backend = OAuth2Backend::new(OAuth2Config::default());
667        let users = backend.list_users().await.unwrap();
668        assert_eq!(users.len(), 0);
669    }
670
671    #[tokio::test]
672    async fn test_verify_identity_not_cached() {
673        let backend = OAuth2Backend::new(OAuth2Config::default());
674        let username = Username::new("user@example.com".to_string()).unwrap();
675        let verified = backend.verify_identity(&username).await.unwrap();
676        assert!(!verified);
677    }
678
679    #[tokio::test]
680    async fn test_verify_identity_cached() {
681        let backend = OAuth2Backend::new(OAuth2Config::default());
682        let username = Username::new("cached@example.com".to_string()).unwrap();
683
684        {
685            let mut cache = backend.token_cache.write().await;
686            cache.insert(
687                username.to_string(),
688                TokenCacheEntry {
689                    username: username.to_string(),
690                    expires_at: SystemTime::now() + Duration::from_secs(300),
691                },
692            );
693        }
694
695        let verified = backend.verify_identity(&username).await.unwrap();
696        assert!(verified);
697    }
698
699    // ========================================================================
700    // Token Cache Tests
701    // ========================================================================
702
703    #[tokio::test]
704    async fn test_token_cache_insertion() {
705        let backend = OAuth2Backend::new(OAuth2Config::default());
706
707        {
708            let mut cache = backend.token_cache.write().await;
709            cache.insert(
710                "user@example.com".to_string(),
711                TokenCacheEntry {
712                    username: "user@example.com".to_string(),
713                    expires_at: SystemTime::now() + Duration::from_secs(300),
714                },
715            );
716        }
717
718        let cache = backend.token_cache.read().await;
719        assert_eq!(cache.len(), 1);
720        assert!(cache.contains_key("user@example.com"));
721    }
722
723    #[tokio::test]
724    async fn test_token_cache_expiration() {
725        let backend = OAuth2Backend::new(OAuth2Config::default());
726
727        {
728            let mut cache = backend.token_cache.write().await;
729            cache.insert(
730                "expired@example.com".to_string(),
731                TokenCacheEntry {
732                    username: "expired@example.com".to_string(),
733                    expires_at: SystemTime::now() - Duration::from_secs(1),
734                },
735            );
736        }
737
738        // Cache contains entry but it's expired
739        let cache = backend.token_cache.read().await;
740        let entry = cache.get("expired@example.com").unwrap();
741        assert!(entry.expires_at < SystemTime::now());
742    }
743
744    #[tokio::test]
745    async fn test_token_cache_multiple_users() {
746        let backend = OAuth2Backend::new(OAuth2Config::default());
747
748        {
749            let mut cache = backend.token_cache.write().await;
750            for i in 1..=5 {
751                cache.insert(
752                    format!("user{}@example.com", i),
753                    TokenCacheEntry {
754                        username: format!("user{}@example.com", i),
755                        expires_at: SystemTime::now() + Duration::from_secs(300),
756                    },
757                );
758            }
759        }
760
761        let cache = backend.token_cache.read().await;
762        assert_eq!(cache.len(), 5);
763    }
764
765    #[tokio::test]
766    async fn test_list_users_with_cached_tokens() {
767        let backend = OAuth2Backend::new(OAuth2Config::default());
768
769        {
770            let mut cache = backend.token_cache.write().await;
771            cache.insert(
772                "user1@example.com".to_string(),
773                TokenCacheEntry {
774                    username: "user1@example.com".to_string(),
775                    expires_at: SystemTime::now() + Duration::from_secs(300),
776                },
777            );
778            cache.insert(
779                "user2@example.com".to_string(),
780                TokenCacheEntry {
781                    username: "user2@example.com".to_string(),
782                    expires_at: SystemTime::now() + Duration::from_secs(300),
783                },
784            );
785        }
786
787        let users = backend.list_users().await.unwrap();
788        assert_eq!(users.len(), 2);
789    }
790
791    // ========================================================================
792    // Claims Structure Tests
793    // ========================================================================
794
795    #[test]
796    fn test_claims_structure() {
797        let claims = Claims {
798            sub: "user123".to_string(),
799            email: Some("user@example.com".to_string()),
800            exp: 1234567890,
801            iat: 1234567800,
802            iss: "https://accounts.google.com".to_string(),
803            aud: "client-id".to_string(),
804        };
805        assert_eq!(claims.sub, "user123");
806        assert_eq!(claims.email.unwrap(), "user@example.com");
807    }
808
809    #[test]
810    fn test_claims_without_email() {
811        let claims = Claims {
812            sub: "user123".to_string(),
813            email: None,
814            exp: 1234567890,
815            iat: 1234567800,
816            iss: "https://accounts.google.com".to_string(),
817            aud: "client-id".to_string(),
818        };
819        assert_eq!(claims.sub, "user123");
820        assert!(claims.email.is_none());
821    }
822
823    // ========================================================================
824    // Token Cache Entry Tests
825    // ========================================================================
826
827    #[test]
828    fn test_token_cache_entry() {
829        let entry = TokenCacheEntry {
830            username: "user@example.com".to_string(),
831            expires_at: SystemTime::now() + Duration::from_secs(300),
832        };
833        assert_eq!(entry.username, "user@example.com");
834        assert!(entry.expires_at > SystemTime::now());
835    }
836
837    #[test]
838    fn test_token_cache_entry_expired() {
839        let entry = TokenCacheEntry {
840            username: "user@example.com".to_string(),
841            expires_at: SystemTime::now() - Duration::from_secs(10),
842        };
843        assert!(entry.expires_at < SystemTime::now());
844    }
845
846    // ========================================================================
847    // Provider-specific Tests
848    // ========================================================================
849
850    #[test]
851    fn test_google_provider_config() {
852        let provider = OidcProvider::Google {
853            client_id: "google-client-id".to_string(),
854            client_secret: "google-secret".to_string(),
855        };
856
857        if let OidcProvider::Google { client_id, .. } = &provider {
858            assert_eq!(client_id, "google-client-id");
859        } else {
860            panic!("Expected Google provider");
861        }
862    }
863
864    #[test]
865    fn test_microsoft_provider_config() {
866        let provider = OidcProvider::Microsoft {
867            tenant_id: "tenant-123".to_string(),
868            client_id: "ms-client-id".to_string(),
869            client_secret: "ms-secret".to_string(),
870        };
871
872        if let OidcProvider::Microsoft { tenant_id, .. } = &provider {
873            assert_eq!(tenant_id, "tenant-123");
874        } else {
875            panic!("Expected Microsoft provider");
876        }
877    }
878
879    #[test]
880    fn test_generic_provider_config() {
881        let provider = OidcProvider::Generic {
882            issuer_url: "https://auth.example.com".to_string(),
883            client_id: "generic-client".to_string(),
884            client_secret: "generic-secret".to_string(),
885            jwks_url: "https://auth.example.com/.well-known/jwks.json".to_string(),
886        };
887
888        if let OidcProvider::Generic { issuer_url, .. } = &provider {
889            assert_eq!(issuer_url, "https://auth.example.com");
890        } else {
891            panic!("Expected Generic provider");
892        }
893    }
894
895    // ========================================================================
896    // Algorithm Tests
897    // ========================================================================
898
899    #[test]
900    fn test_multiple_allowed_algorithms() {
901        let config = OAuth2Config {
902            allowed_algorithms: vec![
903                Algorithm::RS256,
904                Algorithm::RS384,
905                Algorithm::RS512,
906                Algorithm::ES256,
907            ],
908            ..Default::default()
909        };
910        assert_eq!(config.allowed_algorithms.len(), 4);
911        assert!(config.allowed_algorithms.contains(&Algorithm::RS256));
912        assert!(config.allowed_algorithms.contains(&Algorithm::ES256));
913    }
914
915    #[test]
916    fn test_single_algorithm_rs256() {
917        let config = OAuth2Config {
918            allowed_algorithms: vec![Algorithm::RS256],
919            ..Default::default()
920        };
921        assert_eq!(config.allowed_algorithms.len(), 1);
922        assert_eq!(config.allowed_algorithms[0], Algorithm::RS256);
923    }
924
925    // ========================================================================
926    // JWKS Structure Tests
927    // ========================================================================
928
929    #[test]
930    fn test_jwks_structure() {
931        let jwks = Jwks { keys: vec![] };
932        assert_eq!(jwks.keys.len(), 0);
933    }
934
935    #[test]
936    fn test_jwk_structure() {
937        let jwk = Jwk {
938            kid: "key-1".to_string(),
939            kty: "RSA".to_string(),
940            key_use: Some("sig".to_string()),
941            alg: Some("RS256".to_string()),
942            n: Some("modulus".to_string()),
943            e: Some("AQAB".to_string()),
944        };
945        assert_eq!(jwk.kid, "key-1");
946        assert_eq!(jwk.kty, "RSA");
947    }
948
949    // ========================================================================
950    // Edge Cases
951    // ========================================================================
952
953    #[tokio::test]
954    async fn test_authenticate_empty_token() {
955        let backend = OAuth2Backend::new(OAuth2Config::default());
956        let username = Username::new("user@example.com".to_string()).unwrap();
957        let result = backend.authenticate(&username, "").await;
958        assert!(result.is_ok());
959        assert!(!result.unwrap());
960    }
961
962    #[tokio::test]
963    async fn test_authenticate_invalid_token() {
964        let backend = OAuth2Backend::new(OAuth2Config::default());
965        let username = Username::new("user@example.com".to_string()).unwrap();
966        let result = backend.authenticate(&username, "invalid-token").await;
967        assert!(result.is_ok());
968        assert!(!result.unwrap());
969    }
970
971    #[test]
972    fn test_config_with_all_options() {
973        let config = OAuth2Config {
974            provider: OidcProvider::Google {
975                client_id: "client".to_string(),
976                client_secret: "secret".to_string(),
977            },
978            introspection_endpoint: Some("https://oauth.example.com/introspect".to_string()),
979            jwks_cache_ttl: 1800,
980            enable_refresh_tokens: false,
981            allowed_algorithms: vec![Algorithm::RS256, Algorithm::RS384],
982        };
983
984        assert!(config.introspection_endpoint.is_some());
985        assert_eq!(config.jwks_cache_ttl, 1800);
986        assert!(!config.enable_refresh_tokens);
987        assert_eq!(config.allowed_algorithms.len(), 2);
988    }
989
990    // ========================================================================
991    // Username Validation Tests
992    // ========================================================================
993
994    #[tokio::test]
995    async fn test_verify_identity_invalid_username() {
996        let backend = OAuth2Backend::new(OAuth2Config::default());
997        // Create a valid username that's not in cache
998        let username = Username::new("nonexistent@example.com".to_string()).unwrap();
999        let result = backend.verify_identity(&username).await;
1000        assert!(result.is_ok());
1001        assert!(!result.unwrap());
1002    }
1003
1004    // ========================================================================
1005    // Concurrent Access Tests
1006    // ========================================================================
1007
1008    #[tokio::test]
1009    async fn test_concurrent_cache_access() {
1010        let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1011
1012        let mut handles = vec![];
1013        for i in 0..10 {
1014            let backend = Arc::clone(&backend);
1015            let handle = tokio::spawn(async move {
1016                let mut cache = backend.token_cache.write().await;
1017                cache.insert(
1018                    format!("user{}@example.com", i),
1019                    TokenCacheEntry {
1020                        username: format!("user{}@example.com", i),
1021                        expires_at: SystemTime::now() + Duration::from_secs(300),
1022                    },
1023                );
1024            });
1025            handles.push(handle);
1026        }
1027
1028        for handle in handles {
1029            handle.await.unwrap();
1030        }
1031
1032        let cache = backend.token_cache.read().await;
1033        assert_eq!(cache.len(), 10);
1034    }
1035
1036    // ========================================================================
1037    // Error Handling Tests
1038    // ========================================================================
1039
1040    #[tokio::test]
1041    async fn test_introspect_without_endpoint() {
1042        let backend = OAuth2Backend::new(OAuth2Config::default());
1043        let result = backend.introspect_token("test-token").await;
1044        assert!(result.is_err());
1045        assert!(result.unwrap_err().to_string().contains("not configured"));
1046    }
1047
1048    #[tokio::test]
1049    async fn test_refresh_token_disabled() {
1050        let config = OAuth2Config {
1051            enable_refresh_tokens: false,
1052            ..Default::default()
1053        };
1054        let backend = OAuth2Backend::new(config);
1055        let result = backend.refresh_token("refresh-token").await;
1056        assert!(result.is_err());
1057        assert!(result.unwrap_err().to_string().contains("not enabled"));
1058    }
1059
1060    // ========================================================================
1061    // XOAUTH2 SASL Mechanism Tests
1062    // ========================================================================
1063
1064    #[test]
1065    fn test_parse_xoauth2_response_valid() {
1066        let response =
1067            OAuth2Backend::encode_xoauth2_response("user@example.com", "ya29.a0AfH6SMBx...");
1068        let result = OAuth2Backend::parse_xoauth2_response(&response);
1069        assert!(result.is_ok());
1070        let (username, token) = result.unwrap();
1071        assert_eq!(username, "user@example.com");
1072        assert_eq!(token, "ya29.a0AfH6SMBx...");
1073    }
1074
1075    #[test]
1076    fn test_encode_xoauth2_response() {
1077        let encoded = OAuth2Backend::encode_xoauth2_response("test@example.com", "token123");
1078        assert!(!encoded.is_empty());
1079
1080        // Verify it can be decoded back
1081        let (username, token) = OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1082        assert_eq!(username, "test@example.com");
1083        assert_eq!(token, "token123");
1084    }
1085
1086    #[test]
1087    fn test_parse_xoauth2_response_invalid_base64() {
1088        let result = OAuth2Backend::parse_xoauth2_response("not-valid-base64!");
1089        assert!(result.is_err());
1090        assert!(result.unwrap_err().to_string().contains("decode"));
1091    }
1092
1093    #[test]
1094    fn test_parse_xoauth2_response_missing_username() {
1095        // Create response without username
1096        let invalid = BASE64.encode(b"auth=Bearer token123\x01\x01");
1097        let result = OAuth2Backend::parse_xoauth2_response(&invalid);
1098        assert!(result.is_err());
1099        assert!(result.unwrap_err().to_string().contains("username"));
1100    }
1101
1102    #[test]
1103    fn test_parse_xoauth2_response_missing_token() {
1104        // Create response without token
1105        let invalid = BASE64.encode(b"user=test@example.com\x01\x01");
1106        let result = OAuth2Backend::parse_xoauth2_response(&invalid);
1107        assert!(result.is_err());
1108        assert!(result.unwrap_err().to_string().contains("token"));
1109    }
1110
1111    #[test]
1112    fn test_xoauth2_round_trip() {
1113        let original_username = "roundtrip@example.com";
1114        let original_token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9...";
1115
1116        let encoded = OAuth2Backend::encode_xoauth2_response(original_username, original_token);
1117        let (decoded_username, decoded_token) =
1118            OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1119
1120        assert_eq!(decoded_username, original_username);
1121        assert_eq!(decoded_token, original_token);
1122    }
1123
1124    #[test]
1125    fn test_xoauth2_special_characters() {
1126        let username = "user+tag@example.com";
1127        let token = "token-with-special_chars.123";
1128
1129        let encoded = OAuth2Backend::encode_xoauth2_response(username, token);
1130        let (decoded_username, decoded_token) =
1131            OAuth2Backend::parse_xoauth2_response(&encoded).unwrap();
1132
1133        assert_eq!(decoded_username, username);
1134        assert_eq!(decoded_token, token);
1135    }
1136
1137    // ========================================================================
1138    // Token Cache Management Tests
1139    // ========================================================================
1140
1141    #[tokio::test]
1142    async fn test_cleanup_expired_tokens() {
1143        let backend = OAuth2Backend::new(OAuth2Config::default());
1144
1145        {
1146            let mut cache = backend.token_cache.write().await;
1147            // Add expired token
1148            cache.insert(
1149                "expired@example.com".to_string(),
1150                TokenCacheEntry {
1151                    username: "expired@example.com".to_string(),
1152                    expires_at: SystemTime::now() - Duration::from_secs(10),
1153                },
1154            );
1155            // Add valid token
1156            cache.insert(
1157                "valid@example.com".to_string(),
1158                TokenCacheEntry {
1159                    username: "valid@example.com".to_string(),
1160                    expires_at: SystemTime::now() + Duration::from_secs(300),
1161                },
1162            );
1163        }
1164
1165        backend.cleanup_expired_tokens().await;
1166
1167        let cache = backend.token_cache.read().await;
1168        assert_eq!(cache.len(), 1);
1169        assert!(cache.contains_key("valid@example.com"));
1170        assert!(!cache.contains_key("expired@example.com"));
1171    }
1172
1173    #[tokio::test]
1174    async fn test_token_cache_size() {
1175        let backend = OAuth2Backend::new(OAuth2Config::default());
1176
1177        {
1178            let mut cache = backend.token_cache.write().await;
1179            for i in 1..=3 {
1180                cache.insert(
1181                    format!("user{}@example.com", i),
1182                    TokenCacheEntry {
1183                        username: format!("user{}@example.com", i),
1184                        expires_at: SystemTime::now() + Duration::from_secs(300),
1185                    },
1186                );
1187            }
1188        }
1189
1190        let size = backend.token_cache_size().await;
1191        assert_eq!(size, 3);
1192    }
1193
1194    #[tokio::test]
1195    async fn test_invalidate_token() {
1196        let backend = OAuth2Backend::new(OAuth2Config::default());
1197
1198        {
1199            let mut cache = backend.token_cache.write().await;
1200            cache.insert(
1201                "user@example.com".to_string(),
1202                TokenCacheEntry {
1203                    username: "user@example.com".to_string(),
1204                    expires_at: SystemTime::now() + Duration::from_secs(300),
1205                },
1206            );
1207        }
1208
1209        assert_eq!(backend.token_cache_size().await, 1);
1210
1211        backend.invalidate_token("user@example.com").await;
1212
1213        assert_eq!(backend.token_cache_size().await, 0);
1214    }
1215
1216    #[tokio::test]
1217    async fn test_clear_jwks_cache() {
1218        let backend = OAuth2Backend::new(OAuth2Config::default());
1219
1220        {
1221            let mut cache = backend.jwks_cache.write().await;
1222            *cache = Some((Jwks { keys: vec![] }, SystemTime::now()));
1223        }
1224
1225        backend.clear_jwks_cache().await;
1226
1227        let cache = backend.jwks_cache.read().await;
1228        assert!(cache.is_none());
1229    }
1230
1231    // ========================================================================
1232    // Provider URL Tests
1233    // ========================================================================
1234
1235    #[test]
1236    fn test_google_jwks_url() {
1237        let config = OAuth2Config {
1238            provider: OidcProvider::Google {
1239                client_id: "client".to_string(),
1240                client_secret: "secret".to_string(),
1241            },
1242            ..Default::default()
1243        };
1244        let backend = OAuth2Backend::new(config);
1245
1246        // Google JWKS URL is hardcoded in get_jwks method
1247        assert!(matches!(
1248            backend.config.provider,
1249            OidcProvider::Google { .. }
1250        ));
1251    }
1252
1253    #[test]
1254    fn test_microsoft_urls() {
1255        let tenant_id = "tenant-abc-123";
1256        let provider = OidcProvider::Microsoft {
1257            tenant_id: tenant_id.to_string(),
1258            client_id: "client".to_string(),
1259            client_secret: "secret".to_string(),
1260        };
1261
1262        if let OidcProvider::Microsoft { tenant_id: tid, .. } = &provider {
1263            let expected_jwks = format!(
1264                "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
1265                tid
1266            );
1267            assert!(expected_jwks.contains(tenant_id));
1268        }
1269    }
1270
1271    #[test]
1272    fn test_generic_provider_urls() {
1273        let issuer = "https://auth.company.com";
1274        let jwks_url = "https://auth.company.com/.well-known/jwks.json";
1275
1276        let provider = OidcProvider::Generic {
1277            issuer_url: issuer.to_string(),
1278            client_id: "client".to_string(),
1279            client_secret: "secret".to_string(),
1280            jwks_url: jwks_url.to_string(),
1281        };
1282
1283        if let OidcProvider::Generic {
1284            issuer_url,
1285            jwks_url: jwks,
1286            ..
1287        } = &provider
1288        {
1289            assert_eq!(issuer_url, issuer);
1290            assert_eq!(jwks, jwks_url);
1291        }
1292    }
1293
1294    // ========================================================================
1295    // Additional Edge Cases
1296    // ========================================================================
1297
1298    #[tokio::test]
1299    async fn test_multiple_cleanup_calls() {
1300        let backend = OAuth2Backend::new(OAuth2Config::default());
1301
1302        {
1303            let mut cache = backend.token_cache.write().await;
1304            cache.insert(
1305                "expired@example.com".to_string(),
1306                TokenCacheEntry {
1307                    username: "expired@example.com".to_string(),
1308                    expires_at: SystemTime::now() - Duration::from_secs(10),
1309                },
1310            );
1311        }
1312
1313        // Multiple cleanup calls should be safe
1314        backend.cleanup_expired_tokens().await;
1315        backend.cleanup_expired_tokens().await;
1316        backend.cleanup_expired_tokens().await;
1317
1318        let cache = backend.token_cache.read().await;
1319        assert_eq!(cache.len(), 0);
1320    }
1321
1322    #[tokio::test]
1323    async fn test_invalidate_nonexistent_token() {
1324        let backend = OAuth2Backend::new(OAuth2Config::default());
1325        // Should not panic
1326        backend.invalidate_token("nonexistent@example.com").await;
1327        assert_eq!(backend.token_cache_size().await, 0);
1328    }
1329
1330    #[test]
1331    fn test_xoauth2_empty_username() {
1332        let encoded = OAuth2Backend::encode_xoauth2_response("", "token");
1333        let result = OAuth2Backend::parse_xoauth2_response(&encoded);
1334        assert!(result.is_ok());
1335        let (username, _) = result.unwrap();
1336        assert_eq!(username, "");
1337    }
1338
1339    #[test]
1340    fn test_xoauth2_empty_token() {
1341        let encoded = OAuth2Backend::encode_xoauth2_response("user@example.com", "");
1342        let result = OAuth2Backend::parse_xoauth2_response(&encoded);
1343        assert!(result.is_ok());
1344        let (_, token) = result.unwrap();
1345        assert_eq!(token, "");
1346    }
1347
1348    #[test]
1349    fn test_xoauth2_long_token() {
1350        let long_token = "a".repeat(1000);
1351        let encoded = OAuth2Backend::encode_xoauth2_response("user@example.com", &long_token);
1352        let result = OAuth2Backend::parse_xoauth2_response(&encoded);
1353        assert!(result.is_ok());
1354        let (_, token) = result.unwrap();
1355        assert_eq!(token.len(), 1000);
1356    }
1357
1358    // ========================================================================
1359    // Configuration Validation Tests
1360    // ========================================================================
1361
1362    #[test]
1363    fn test_config_validation_minimal() {
1364        let config = OAuth2Config {
1365            provider: OidcProvider::Generic {
1366                issuer_url: "https://minimal.example.com".to_string(),
1367                client_id: "c".to_string(),
1368                client_secret: "s".to_string(),
1369                jwks_url: "https://minimal.example.com/jwks".to_string(),
1370            },
1371            introspection_endpoint: None,
1372            jwks_cache_ttl: 60,
1373            enable_refresh_tokens: false,
1374            allowed_algorithms: vec![Algorithm::RS256],
1375        };
1376
1377        let backend = OAuth2Backend::new(config);
1378        assert!(backend.config.jwks_cache_ttl >= 60);
1379    }
1380
1381    #[test]
1382    fn test_config_validation_maximal() {
1383        let config = OAuth2Config {
1384            provider: OidcProvider::Google {
1385                client_id: "very-long-client-id-with-many-characters".to_string(),
1386                client_secret: "very-long-secret-with-special-chars!@#$%".to_string(),
1387            },
1388            introspection_endpoint: Some(
1389                "https://oauth.googleapis.com/token/introspect".to_string(),
1390            ),
1391            jwks_cache_ttl: 86400,
1392            enable_refresh_tokens: true,
1393            allowed_algorithms: vec![
1394                Algorithm::RS256,
1395                Algorithm::RS384,
1396                Algorithm::RS512,
1397                Algorithm::ES256,
1398                Algorithm::ES384,
1399            ],
1400        };
1401
1402        let backend = OAuth2Backend::new(config);
1403        assert_eq!(backend.config.allowed_algorithms.len(), 5);
1404        assert!(backend.config.enable_refresh_tokens);
1405    }
1406
1407    // ========================================================================
1408    // Thread Safety Tests
1409    // ========================================================================
1410
1411    #[tokio::test]
1412    async fn test_concurrent_jwks_cache_access() {
1413        let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1414
1415        let mut handles = vec![];
1416        for _ in 0..5 {
1417            let backend = Arc::clone(&backend);
1418            let handle = tokio::spawn(async move {
1419                backend.clear_jwks_cache().await;
1420            });
1421            handles.push(handle);
1422        }
1423
1424        for handle in handles {
1425            handle.await.unwrap();
1426        }
1427
1428        let cache = backend.jwks_cache.read().await;
1429        assert!(cache.is_none());
1430    }
1431
1432    #[tokio::test]
1433    async fn test_concurrent_cleanup() {
1434        let backend = Arc::new(OAuth2Backend::new(OAuth2Config::default()));
1435
1436        {
1437            let mut cache = backend.token_cache.write().await;
1438            for i in 0..100 {
1439                cache.insert(
1440                    format!("user{}@example.com", i),
1441                    TokenCacheEntry {
1442                        username: format!("user{}@example.com", i),
1443                        expires_at: if i % 2 == 0 {
1444                            SystemTime::now() + Duration::from_secs(300)
1445                        } else {
1446                            SystemTime::now() - Duration::from_secs(10)
1447                        },
1448                    },
1449                );
1450            }
1451        }
1452
1453        let mut handles = vec![];
1454        for _ in 0..10 {
1455            let backend = Arc::clone(&backend);
1456            let handle = tokio::spawn(async move {
1457                backend.cleanup_expired_tokens().await;
1458            });
1459            handles.push(handle);
1460        }
1461
1462        for handle in handles {
1463            handle.await.unwrap();
1464        }
1465
1466        let cache = backend.token_cache.read().await;
1467        assert_eq!(cache.len(), 50);
1468    }
1469}