nblm_core/auth/oauth/
mod.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use oauth2::{
9    basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointSet,
10    PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse,
11    TokenResponse as OAuth2TokenResponse, TokenUrl,
12};
13use parking_lot::RwLock;
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use time::OffsetDateTime;
17
18use crate::auth::{ProviderKind, TokenProvider};
19use crate::env::ApiProfile;
20use crate::error::{Error, Result};
21
22pub mod loopback;
23
24// ============================================================================
25// OAuth Configuration
26// ============================================================================
27
/// OAuth2 configuration for the Authorization Code Flow with PKCE.
///
/// Holds the endpoints, client credentials and request parameters that
/// [`OAuthFlow`] needs to build authorization URLs and call the token endpoint.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Authorization endpoint the user's browser is sent to.
    pub auth_endpoint: String,
    /// Token endpoint used for code exchange and token refresh.
    pub token_endpoint: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret; when `None` no secret is configured on the client.
    pub client_secret: Option<String>,
    /// Redirect URI registered for this client.
    pub redirect_uri: String,
    /// Scopes requested during authorization.
    pub scopes: Vec<String>,
    /// Optional `audience` query parameter appended to the authorization URL.
    pub audience: Option<String>,
    /// Extra query parameters added to the authorization request.
    pub additional_params: HashMap<String, String>,
}
40
41impl OAuthConfig {
42    pub const DEFAULT_REDIRECT_URI: &str = "http://127.0.0.1:4317";
43    const AUTH_ENDPOINT: &str = "https://accounts.google.com/o/oauth2/v2/auth";
44    const TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
45    const SCOPE_CLOUD_PLATFORM: &str = "https://www.googleapis.com/auth/cloud-platform";
46    const SCOPE_DRIVE_FILE: &str = "https://www.googleapis.com/auth/drive.file";
47
48    /// Create a default Google OAuth2 configuration for NotebookLM Enterprise
49    pub fn google_default(_project_number: &str) -> Result<Self> {
50        let client_id = std::env::var("NBLM_OAUTH_CLIENT_ID").map_err(|_| {
51            Error::TokenProvider(
52                "NBLM_OAUTH_CLIENT_ID is required for user OAuth authentication".to_string(),
53            )
54        })?;
55
56        let audience = std::env::var("NBLM_OAUTH_AUDIENCE").ok();
57
58        Ok(Self {
59            auth_endpoint: Self::AUTH_ENDPOINT.to_string(),
60            token_endpoint: Self::TOKEN_ENDPOINT.to_string(),
61            client_id,
62            client_secret: std::env::var("NBLM_OAUTH_CLIENT_SECRET").ok(),
63            redirect_uri: std::env::var("NBLM_OAUTH_REDIRECT_URI")
64                .unwrap_or_else(|_| Self::DEFAULT_REDIRECT_URI.to_string()),
65            scopes: vec![
66                Self::SCOPE_CLOUD_PLATFORM.to_string(),
67                Self::SCOPE_DRIVE_FILE.to_string(),
68            ],
69            audience,
70            additional_params: HashMap::new(),
71        })
72    }
73}
74
75// ============================================================================
76// OAuth Tokens
77// ============================================================================
78
79/// OAuth2 tokens returned from token endpoint
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct OAuthTokens {
82    pub access_token: String,
83    pub refresh_token: Option<String>,
84    pub expires_at: OffsetDateTime,
85    pub scope: Option<String>,
86    pub token_type: String,
87}
88
89impl OAuthTokens {
90    /// Create OAuthTokens from oauth2-rs token response
91    pub fn from_oauth2_response(
92        response: StandardTokenResponse<
93            oauth2::EmptyExtraTokenFields,
94            oauth2::basic::BasicTokenType,
95        >,
96        issued_at: OffsetDateTime,
97    ) -> Self {
98        let expires_at = issued_at
99            + response
100                .expires_in()
101                .map(|d| Duration::from_secs(d.as_secs()))
102                .unwrap_or_else(|| Duration::from_secs(3600));
103
104        let scope = response.scopes().map(|scopes| {
105            scopes
106                .iter()
107                .map(|s| s.as_str().to_string())
108                .collect::<Vec<_>>()
109                .join(" ")
110        });
111
112        Self {
113            access_token: response.access_token().secret().to_string(),
114            refresh_token: response.refresh_token().map(|rt| rt.secret().to_string()),
115            expires_at,
116            scope,
117            token_type: match response.token_type() {
118                oauth2::basic::BasicTokenType::Bearer => "Bearer".to_string(),
119                oauth2::basic::BasicTokenType::Mac => "MAC".to_string(),
120                oauth2::basic::BasicTokenType::Extension(s) => s.clone(),
121            },
122        }
123    }
124}
125
126// ============================================================================
127// Token Cache Entry
128// ============================================================================
129
130/// In-memory cache entry for OAuth tokens
131#[derive(Debug, Clone)]
132pub struct TokenCacheEntry {
133    pub tokens: OAuthTokens,
134    pub refresh_margin: Duration,
135}
136
137impl TokenCacheEntry {
138    pub fn new(tokens: OAuthTokens) -> Self {
139        Self {
140            tokens,
141            refresh_margin: Duration::from_secs(60), // Default 60 seconds
142        }
143    }
144
145    /// Check if token needs refresh
146    pub fn needs_refresh(&self, now: OffsetDateTime) -> bool {
147        now >= (self.tokens.expires_at - self.refresh_margin)
148    }
149}
150
151// ============================================================================
152// Token Store Key
153// ============================================================================
154
155/// Key for storing tokens in RefreshTokenStore
156#[derive(Debug, Clone, PartialEq, Eq, Hash)]
157pub struct TokenStoreKey {
158    pub profile: ApiProfile,
159    pub project_number: Option<String>,
160    pub endpoint_location: Option<String>,
161    pub user_hint: Option<String>,
162}
163
164impl fmt::Display for TokenStoreKey {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        let mut parts = vec![self.profile.as_str().to_string()];
167
168        if let Some(ref project) = self.project_number {
169            parts.push(format!("project={}", project));
170        }
171
172        if let Some(ref location) = self.endpoint_location {
173            parts.push(format!("location={}", location));
174        }
175
176        if let Some(ref user) = self.user_hint {
177            parts.push(format!("user={}", user));
178        }
179
180        write!(f, "{}", parts.join(":"))
181    }
182}
183
184// ============================================================================
185// Serialized Tokens (for storage)
186// ============================================================================
187
188/// Serialized token data for storage
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct SerializedTokens {
191    pub refresh_token: String,
192    pub scopes: Vec<String>,
193    pub expires_at: Option<OffsetDateTime>,
194    pub token_type: String,
195    #[serde(with = "time::serde::rfc3339")]
196    pub updated_at: OffsetDateTime,
197}
198
199/// Credentials file format
200#[derive(Debug, Serialize, Deserialize)]
201struct CredentialsFile {
202    version: u32,
203    entries: HashMap<String, SerializedTokens>,
204}
205
206impl CredentialsFile {
207    fn new() -> Self {
208        Self {
209            version: 1,
210            entries: HashMap::new(),
211        }
212    }
213}
214
215// ============================================================================
216// RefreshTokenStore Trait
217// ============================================================================
218
219/// Trait for storing and retrieving refresh tokens
220#[async_trait]
221pub trait RefreshTokenStore: Send + Sync {
222    /// Load tokens for the given key
223    async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>>;
224
225    /// Save tokens for the given key
226    async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()>;
227
228    /// Delete tokens for the given key
229    async fn delete(&self, key: &TokenStoreKey) -> Result<()>;
230}
231
232// ============================================================================
233// FileRefreshTokenStore
234// ============================================================================
235
/// [`RefreshTokenStore`] backed by a single JSON credentials file.
pub struct FileRefreshTokenStore {
    /// Location of the credentials file.
    file_path: PathBuf,
}
240
241impl FileRefreshTokenStore {
242    /// Create a new FileRefreshTokenStore
243    pub fn new() -> Result<Self> {
244        let dirs = directories::ProjectDirs::from("com", "nblm", "nblm-rs")
245            .ok_or_else(|| Error::TokenProvider("failed to find config directory".to_string()))?;
246
247        let config_dir = dirs.config_dir();
248        let file_path = config_dir.join("credentials.json");
249
250        Self::from_path(file_path)
251    }
252
253    /// Create a store backed by an explicit credentials file path.
254    pub fn from_path(path: impl Into<PathBuf>) -> Result<Self> {
255        Ok(Self {
256            file_path: path.into(),
257        })
258    }
259
260    /// Ensure config directory exists with proper permissions (async)
261    async fn ensure_config_dir(&self) -> Result<()> {
262        if let Some(config_dir) = self.file_path.parent() {
263            tokio::fs::create_dir_all(config_dir).await.map_err(|e| {
264                Error::TokenProvider(format!("failed to create config directory: {}", e))
265            })?;
266
267            #[cfg(unix)]
268            {
269                use std::os::unix::fs::PermissionsExt;
270                let mut perms = tokio::fs::metadata(config_dir)
271                    .await
272                    .map_err(|e| {
273                        Error::TokenProvider(format!("failed to get config dir metadata: {}", e))
274                    })?
275                    .permissions();
276                perms.set_mode(0o700);
277                tokio::fs::set_permissions(config_dir, perms)
278                    .await
279                    .map_err(|e| {
280                        Error::TokenProvider(format!("failed to set config dir permissions: {}", e))
281                    })?;
282            }
283        }
284        Ok(())
285    }
286
287    /// Load credentials file
288    async fn load_file(&self) -> Result<CredentialsFile> {
289        self.ensure_config_dir().await?;
290
291        if !self.file_path.exists() {
292            return Ok(CredentialsFile::new());
293        }
294
295        let content = tokio::fs::read_to_string(&self.file_path)
296            .await
297            .map_err(|e| Error::TokenProvider(format!("failed to read credentials file: {}", e)))?;
298
299        let file: CredentialsFile = serde_json::from_str(&content).map_err(|e| {
300            Error::TokenProvider(format!("failed to parse credentials file: {}", e))
301        })?;
302
303        Ok(file)
304    }
305
306    /// Save credentials file
307    async fn save_file(&self, file: &CredentialsFile) -> Result<()> {
308        self.ensure_config_dir().await?;
309
310        let content = serde_json::to_string_pretty(file)
311            .map_err(|e| Error::TokenProvider(format!("failed to serialize credentials: {}", e)))?;
312
313        // Write to temp file first, then rename (atomic write)
314        // Use a unique temporary file name to avoid conflicts in concurrent writes
315        let random_suffix = {
316            use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
317            use rand::RngCore;
318            let mut rng = rand::rng();
319            let mut random_bytes = [0u8; 8];
320            rng.fill_bytes(&mut random_bytes);
321            URL_SAFE_NO_PAD.encode(random_bytes)
322        };
323
324        let temp_path = self.file_path.with_file_name(format!(
325            "{}.{}.tmp",
326            self.file_path
327                .file_name()
328                .and_then(|n| n.to_str())
329                .unwrap_or("credentials.json"),
330            random_suffix
331        ));
332        tokio::fs::write(&temp_path, content).await.map_err(|e| {
333            Error::TokenProvider(format!("failed to write credentials file: {}", e))
334        })?;
335
336        #[cfg(unix)]
337        {
338            use std::os::unix::fs::PermissionsExt;
339            if let Ok(metadata) = tokio::fs::metadata(&temp_path).await {
340                let mut perms = metadata.permissions();
341                perms.set_mode(0o600);
342                let _ = tokio::fs::set_permissions(&temp_path, perms).await;
343            }
344        }
345
346        tokio::fs::rename(&temp_path, &self.file_path)
347            .await
348            .map_err(|e| Error::TokenProvider(format!("failed to rename temp file: {}", e)))?;
349
350        Ok(())
351    }
352}
353
354#[async_trait]
355impl RefreshTokenStore for FileRefreshTokenStore {
356    async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>> {
357        let file = self.load_file().await?;
358        Ok(file.entries.get(&key.to_string()).cloned())
359    }
360
361    async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()> {
362        let mut file = self.load_file().await?;
363        file.entries.insert(key.to_string(), tokens.clone());
364        self.save_file(&file).await
365    }
366
367    async fn delete(&self, key: &TokenStoreKey) -> Result<()> {
368        let mut file = self.load_file().await?;
369        file.entries.remove(&key.to_string());
370        self.save_file(&file).await
371    }
372}
373
374// ============================================================================
375// OAuth Flow
376// ============================================================================
377
378/// Parameters for building authorization URL
379#[derive(Debug, Clone)]
380pub struct AuthorizeParams {
381    pub state: Option<String>,
382    pub code_challenge: Option<String>,
383    pub code_challenge_method: Option<String>,
384}
385
386/// Context for authorization flow
387#[derive(Debug, Clone)]
388pub struct AuthorizeContext {
389    pub url: String,
390    pub state: String,
391    pub code_verifier: String,
392    pub expires_at: OffsetDateTime,
393}
394
395/// OAuth2 Authorization Code Flow with PKCE
396pub struct OAuthFlow {
397    client: BasicClient<
398        EndpointSet,
399        oauth2::EndpointNotSet,
400        oauth2::EndpointNotSet,
401        oauth2::EndpointNotSet,
402        EndpointSet,
403    >,
404    config: OAuthConfig,
405    http: Arc<Client>,
406}
407
408impl OAuthFlow {
409    /// Create a new OAuthFlow
410    pub fn new(config: OAuthConfig, http: Arc<Client>) -> Result<Self> {
411        let client_id = ClientId::new(config.client_id.clone());
412        let auth_url = AuthUrl::new(config.auth_endpoint.clone())
413            .map_err(|e| Error::TokenProvider(format!("invalid auth_url: {}", e)))?;
414        let token_url = TokenUrl::new(config.token_endpoint.clone())
415            .map_err(|e| Error::TokenProvider(format!("invalid token_url: {}", e)))?;
416        let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
417            .map_err(|e| Error::TokenProvider(format!("invalid redirect_url: {}", e)))?;
418
419        let mut client_builder = BasicClient::new(client_id)
420            .set_auth_uri(auth_url)
421            .set_token_uri(token_url)
422            .set_redirect_uri(redirect_url);
423
424        if let Some(ref client_secret) = config.client_secret {
425            client_builder =
426                client_builder.set_client_secret(ClientSecret::new(client_secret.clone()));
427        }
428
429        Ok(Self {
430            client: client_builder,
431            config,
432            http,
433        })
434    }
435
436    /// Build authorization URL with PKCE
437    pub fn build_authorize_url(&self, params: &AuthorizeParams) -> AuthorizeContext {
438        // Generate PKCE challenge and verifier
439        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
440
441        // Generate CSRF token
442        let csrf_token = if let Some(ref state) = params.state {
443            CsrfToken::new(state.clone())
444        } else {
445            CsrfToken::new_random()
446        };
447
448        // Build authorization URL
449        let mut auth_request = self.client.authorize_url(|| csrf_token.clone());
450
451        // Add scopes
452        for scope_str in &self.config.scopes {
453            auth_request = auth_request.add_scope(Scope::new(scope_str.clone()));
454        }
455
456        // Set PKCE challenge
457        auth_request = auth_request.set_pkce_challenge(pkce_challenge);
458
459        // Add additional parameters
460        for (key, value) in &self.config.additional_params {
461            auth_request = auth_request.add_extra_param(key, value);
462        }
463
464        // Build the URL
465        let (auth_url, csrf_token_actual) = auth_request.url();
466
467        // Add Google-specific parameters
468        let mut url = url::Url::parse(auth_url.as_str()).expect("invalid auth_url");
469        url.query_pairs_mut()
470            .append_pair("access_type", "offline")
471            .append_pair("prompt", "consent");
472
473        if let Some(ref audience) = self.config.audience {
474            url.query_pairs_mut().append_pair("audience", audience);
475        }
476
477        let expires_at = OffsetDateTime::now_utc() + Duration::from_secs(600); // 10 minutes
478
479        AuthorizeContext {
480            url: url.to_string(),
481            state: csrf_token_actual.secret().to_string(),
482            code_verifier: pkce_verifier.secret().to_string(),
483            expires_at,
484        }
485    }
486
487    /// Exchange authorization code for tokens
488    pub async fn exchange_code(
489        &self,
490        context: &AuthorizeContext,
491        code: &str,
492    ) -> Result<OAuthTokens> {
493        let code = AuthorizationCode::new(code.to_string());
494        let pkce_verifier = PkceCodeVerifier::new(context.code_verifier.clone());
495
496        let token_request = self
497            .client
498            .exchange_code(code)
499            .set_pkce_verifier(pkce_verifier);
500
501        let token_response = token_request
502            .request_async(self.http.as_ref())
503            .await
504            .map_err(|e| Error::TokenProvider(format!("oauth token exchange failed: {}", e)))?;
505
506        Ok(OAuthTokens::from_oauth2_response(
507            token_response,
508            OffsetDateTime::now_utc(),
509        ))
510    }
511
512    /// Refresh access token using refresh token
513    pub async fn refresh(&self, refresh_token: &str) -> Result<OAuthTokens> {
514        let refresh_token = RefreshToken::new(refresh_token.to_string());
515
516        let token_request = self.client.exchange_refresh_token(&refresh_token);
517
518        let token_response = token_request
519            .request_async(self.http.as_ref())
520            .await
521            .map_err(|e| Error::TokenProvider(format!("oauth token refresh failed: {}", e)))?;
522
523        Ok(OAuthTokens::from_oauth2_response(
524            token_response,
525            OffsetDateTime::now_utc(),
526        ))
527    }
528
529    /// Revoke refresh token (future use, NOOP for now)
530    pub async fn revoke(&self, _refresh_token: &str) -> Result<()> {
531        // TODO: Implement token revocation using oauth2-rs
532        Ok(())
533    }
534}
535
536// ============================================================================
537// RefreshTokenProvider
538// ============================================================================
539
540/// TokenProvider implementation using refresh tokens
541pub struct RefreshTokenProvider<S: RefreshTokenStore> {
542    flow: OAuthFlow,
543    store: Arc<S>,
544    cache: RwLock<Option<TokenCacheEntry>>,
545    store_key: TokenStoreKey,
546}
547
548impl<S: RefreshTokenStore> RefreshTokenProvider<S> {
549    /// Create a new RefreshTokenProvider
550    pub fn new(flow: OAuthFlow, store: Arc<S>, store_key: TokenStoreKey) -> Self {
551        Self {
552            flow,
553            store,
554            cache: RwLock::new(None),
555            store_key,
556        }
557    }
558
559    /// Ensure tokens are valid, refreshing if necessary
560    async fn ensure_tokens(&self, force_refresh: bool) -> Result<OAuthTokens> {
561        let now = OffsetDateTime::now_utc();
562
563        // Check cache first
564        if !force_refresh {
565            if let Some(ref entry) = *self.cache.read() {
566                if !entry.needs_refresh(now) {
567                    return Ok(entry.tokens.clone());
568                }
569            }
570        }
571
572        // Load refresh token from store
573        let stored = self
574            .store
575            .load(&self.store_key)
576            .await?
577            .ok_or_else(|| Error::TokenProvider("refresh token unavailable".to_string()))?;
578
579        // Refresh access token
580        let tokens = self.flow.refresh(&stored.refresh_token).await?;
581        let refresh_token = tokens
582            .refresh_token
583            .clone()
584            .unwrap_or_else(|| stored.refresh_token.clone());
585        let scopes: Vec<String> = if let Some(scopes_str) = tokens.scope.as_ref() {
586            scopes_str.split_whitespace().map(String::from).collect()
587        } else if !stored.scopes.is_empty() {
588            stored.scopes.clone()
589        } else {
590            Vec::new()
591        };
592        let token_type = if !tokens.token_type.is_empty() {
593            tokens.token_type.clone()
594        } else {
595            stored.token_type.clone()
596        };
597
598        // Update cache
599        {
600            let mut cache = self.cache.write();
601            *cache = Some(TokenCacheEntry::new(tokens.clone()));
602        }
603
604        // Persist refresh token details (preserve previous token when refresh response omits it)
605        let serialized = SerializedTokens {
606            refresh_token,
607            scopes,
608            expires_at: Some(tokens.expires_at),
609            token_type,
610            updated_at: now,
611        };
612        self.store.save(&self.store_key, &serialized).await?;
613
614        Ok(tokens)
615    }
616}
617
618#[async_trait]
619impl<S: RefreshTokenStore> TokenProvider for RefreshTokenProvider<S> {
620    async fn access_token(&self) -> Result<String> {
621        let tokens = self.ensure_tokens(false).await?;
622        Ok(tokens.access_token)
623    }
624
625    async fn refresh_token(&self) -> Result<String> {
626        // Force refresh to get new access token
627        let tokens = self.ensure_tokens(true).await?;
628        Ok(tokens.access_token)
629    }
630
631    fn kind(&self) -> ProviderKind {
632        ProviderKind::UserOauth
633    }
634}
635
636#[cfg(test)]
637mod tests {
638    use super::*;
639    use tempfile::tempdir;
640    use wiremock::matchers::{method, path};
641    use wiremock::{Mock, MockServer, ResponseTemplate};
642
643    #[tokio::test]
644    async fn test_token_cache_entry_needs_refresh() {
645        let now = OffsetDateTime::now_utc();
646        let expires_at = now + Duration::from_secs(120); // 2 minutes
647        let tokens = OAuthTokens {
648            access_token: "test-token".to_string(),
649            refresh_token: None,
650            expires_at,
651            scope: None,
652            token_type: "Bearer".to_string(),
653        };
654
655        let entry = TokenCacheEntry::new(tokens);
656
657        // Should not need refresh immediately
658        assert!(!entry.needs_refresh(now));
659
660        // Should need refresh when close to expiry (within margin)
661        let near_expiry = expires_at - Duration::from_secs(30);
662        assert!(entry.needs_refresh(near_expiry));
663    }
664
665    #[tokio::test]
666    async fn test_token_store_key_display() {
667        let key = TokenStoreKey {
668            profile: ApiProfile::Enterprise,
669            project_number: Some("123456".to_string()),
670            endpoint_location: Some("global".to_string()),
671            user_hint: None,
672        };
673
674        let display = key.to_string();
675        assert!(display.contains("enterprise"));
676        assert!(display.contains("project=123456"));
677        assert!(display.contains("location=global"));
678    }
679
680    #[tokio::test]
681    async fn test_oauth_tokens_from_oauth2_response() {
682        use oauth2::StandardTokenResponse;
683
684        let now = OffsetDateTime::now_utc();
685        // Manually create a response with all fields using serde_json
686        let json_response = serde_json::json!({
687            "access_token": "access-token-123",
688            "refresh_token": "refresh-token-456",
689            "expires_in": 3600,
690            "token_type": "Bearer",
691            "scope": "scope1 scope2"
692        });
693
694        let response: StandardTokenResponse<
695            oauth2::EmptyExtraTokenFields,
696            oauth2::basic::BasicTokenType,
697        > = serde_json::from_value(json_response).unwrap();
698
699        let tokens = OAuthTokens::from_oauth2_response(response, now);
700        assert_eq!(tokens.access_token, "access-token-123");
701        assert_eq!(tokens.refresh_token, Some("refresh-token-456".to_string()));
702        assert_eq!(tokens.scope, Some("scope1 scope2".to_string()));
703        assert_eq!(tokens.token_type, "Bearer");
704        assert!(tokens.expires_at > now);
705    }
706
707    #[tokio::test]
708    async fn test_oauth_flow_refresh_token() {
709        let server = MockServer::start().await;
710
711        Mock::given(method("POST"))
712            .and(path("/token"))
713            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
714                "access_token": "new-access-token",
715                "expires_in": 3600,
716                "token_type": "Bearer"
717            })))
718            .mount(&server)
719            .await;
720
721        let config = OAuthConfig {
722            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
723            token_endpoint: format!("{}/token", server.uri()),
724            client_id: "test-client-id".to_string(),
725            client_secret: None,
726            redirect_uri: "http://127.0.0.1:4317".to_string(),
727            scopes: vec!["scope1".to_string()],
728            audience: None,
729            additional_params: HashMap::new(),
730        };
731
732        let http = Arc::new(Client::new());
733        let flow = OAuthFlow::new(config, http).unwrap();
734
735        let tokens = flow.refresh("refresh-token-123").await.unwrap();
736        assert_eq!(tokens.access_token, "new-access-token");
737    }
738
739    // Test 1: PKCE validity tests (security critical)
740    #[test]
741    fn test_pkce_code_verifier_and_challenge_generation() {
742        let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
743
744        // Verify code_verifier is base64url encoded and proper length
745        let verifier_str = verifier.secret();
746        assert!(!verifier_str.is_empty());
747        assert!(verifier_str.len() >= 43 && verifier_str.len() <= 128);
748        assert!(verifier_str
749            .chars()
750            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
751
752        // Verify code_challenge is base64url encoded SHA256 hash
753        let challenge_str = challenge.as_str();
754        assert!(!challenge_str.is_empty());
755        assert_eq!(challenge_str.len(), 43); // SHA256 hash is 32 bytes, base64url is 43 chars
756        assert!(challenge_str
757            .chars()
758            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
759    }
760
761    #[test]
762    fn test_pkce_generates_unique_values() {
763        let (challenge1, verifier1) = PkceCodeChallenge::new_random_sha256();
764        let (challenge2, verifier2) = PkceCodeChallenge::new_random_sha256();
765
766        // Each generation should produce unique values
767        assert_ne!(verifier1.secret(), verifier2.secret());
768        assert_ne!(challenge1.as_str(), challenge2.as_str());
769    }
770
771    #[test]
772    fn test_state_generation_uniqueness() {
773        let state1 = CsrfToken::new_random();
774        let state2 = CsrfToken::new_random();
775
776        // States should be unique
777        assert_ne!(state1.secret(), state2.secret());
778        assert!(!state1.secret().is_empty());
779        assert!(!state2.secret().is_empty());
780    }
781
782    // Test 2: Token expiration boundary tests
783    #[test]
784    fn test_token_needs_refresh_at_exact_expiry() {
785        let now = OffsetDateTime::now_utc();
786        let expires_at = now + Duration::from_secs(60);
787        let tokens = OAuthTokens {
788            access_token: "test-token".to_string(),
789            refresh_token: None,
790            expires_at,
791            scope: None,
792            token_type: "Bearer".to_string(),
793        };
794
795        let entry = TokenCacheEntry::new(tokens);
796
797        // At exact expiry time (accounting for margin)
798        assert!(entry.needs_refresh(expires_at));
799    }
800
801    #[test]
802    fn test_token_needs_refresh_just_before_margin() {
803        let now = OffsetDateTime::now_utc();
804        let expires_at = now + Duration::from_secs(120);
805        let tokens = OAuthTokens {
806            access_token: "test-token".to_string(),
807            refresh_token: None,
808            expires_at,
809            scope: None,
810            token_type: "Bearer".to_string(),
811        };
812
813        let entry = TokenCacheEntry::new(tokens);
814
815        // Just before margin (61 seconds before expiry, margin is 60)
816        let before_margin = expires_at - Duration::from_secs(61);
817        assert!(!entry.needs_refresh(before_margin));
818
819        // Just at margin (60 seconds before expiry)
820        let at_margin = expires_at - Duration::from_secs(60);
821        assert!(entry.needs_refresh(at_margin));
822    }
823
824    #[test]
825    fn test_token_needs_refresh_after_expiry() {
826        let now = OffsetDateTime::now_utc();
827        let expires_at = now - Duration::from_secs(10); // Already expired
828        let tokens = OAuthTokens {
829            access_token: "test-token".to_string(),
830            refresh_token: None,
831            expires_at,
832            scope: None,
833            token_type: "Bearer".to_string(),
834        };
835
836        let entry = TokenCacheEntry::new(tokens);
837        assert!(entry.needs_refresh(now));
838    }
839
840    // Test 3: File store concurrent access tests
841    #[tokio::test]
842    #[serial_test::serial]
843    async fn test_file_store_concurrent_saves() {
844        use std::sync::Arc;
845        use tokio::task::JoinSet;
846
847        let temp_dir = tempdir().unwrap();
848        let store_path = temp_dir.path().join("credentials.json");
849        let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
850        let key = TokenStoreKey {
851            profile: ApiProfile::Enterprise,
852            project_number: Some("concurrent-test".to_string()),
853            endpoint_location: Some("global".to_string()),
854            user_hint: None,
855        };
856
857        let mut join_set = JoinSet::new();
858
859        // Spawn 10 concurrent save operations
860        for i in 0..10 {
861            let store_clone = Arc::clone(&store);
862            let key_clone = key.clone();
863            join_set.spawn(async move {
864                let tokens = SerializedTokens {
865                    refresh_token: format!("token-{}", i),
866                    scopes: vec!["scope1".to_string()],
867                    expires_at: Some(OffsetDateTime::now_utc() + Duration::from_secs(3600)),
868                    token_type: "Bearer".to_string(),
869                    updated_at: OffsetDateTime::now_utc(),
870                };
871                store_clone.save(&key_clone, &tokens).await
872            });
873        }
874
875        // Wait for all operations to complete
876        while let Some(result) = join_set.join_next().await {
877            result.unwrap().unwrap();
878        }
879
880        // Verify that the final state is consistent
881        let loaded = store.load(&key).await.unwrap();
882        assert!(loaded.is_some());
883        let loaded = loaded.unwrap();
884        assert!(loaded.refresh_token.starts_with("token-"));
885
886        // Cleanup
887        store.delete(&key).await.unwrap();
888    }
889
890    #[tokio::test]
891    #[serial_test::serial]
892    async fn test_file_store_atomic_write() {
893        let temp_dir = tempdir().unwrap();
894        let store_path = temp_dir.path().join("credentials.json");
895        let store = FileRefreshTokenStore::from_path(&store_path).unwrap();
896        let key = TokenStoreKey {
897            profile: ApiProfile::Enterprise,
898            project_number: Some("atomic-test".to_string()),
899            endpoint_location: Some("global".to_string()),
900            user_hint: None,
901        };
902
903        let tokens = SerializedTokens {
904            refresh_token: "initial-token".to_string(),
905            scopes: vec!["scope1".to_string()],
906            expires_at: Some(OffsetDateTime::now_utc() + Duration::from_secs(3600)),
907            token_type: "Bearer".to_string(),
908            updated_at: OffsetDateTime::now_utc(),
909        };
910
911        store.save(&key, &tokens).await.unwrap();
912
913        // Verify temp files are cleaned up after save
914        let entries: Vec<_> = std::fs::read_dir(temp_dir.path())
915            .unwrap()
916            .map(|entry| entry.unwrap().file_name())
917            .collect();
918        assert_eq!(entries.len(), 1);
919        assert_eq!(entries[0], std::ffi::OsStr::new("credentials.json"));
920
921        // Cleanup
922        store.delete(&key).await.unwrap();
923    }
924
925    // Test 5: refresh_token omission handling tests
926    #[tokio::test]
927    #[serial_test::serial]
928    async fn test_refresh_token_preserved_when_omitted_in_response() {
929        let server = MockServer::start().await;
930
931        // Mock refresh endpoint that doesn't return refresh_token
932        Mock::given(method("POST"))
933            .and(path("/token"))
934            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
935                "access_token": "new-access-token",
936                "expires_in": 3600,
937                "token_type": "Bearer"
938            })))
939            .mount(&server)
940            .await;
941
942        let config = OAuthConfig {
943            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
944            token_endpoint: format!("{}/token", server.uri()),
945            client_id: "test-client-id".to_string(),
946            client_secret: None,
947            redirect_uri: "http://127.0.0.1:4317".to_string(),
948            scopes: vec!["scope1".to_string()],
949            audience: None,
950            additional_params: HashMap::new(),
951        };
952
953        let http = Arc::new(Client::new());
954        let flow = OAuthFlow::new(config, http).unwrap();
955        let temp_dir = tempdir().unwrap();
956        let store_path = temp_dir.path().join("credentials.json");
957        let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
958        let key = TokenStoreKey {
959            profile: ApiProfile::Enterprise,
960            project_number: Some("omission-test".to_string()),
961            endpoint_location: Some("global".to_string()),
962            user_hint: None,
963        };
964
965        // Store initial refresh token with expired access token
966        let initial_tokens = SerializedTokens {
967            refresh_token: "original-refresh-token".to_string(),
968            scopes: vec!["scope1".to_string()],
969            expires_at: Some(OffsetDateTime::now_utc() - Duration::from_secs(3600)), // Expired
970            token_type: "Bearer".to_string(),
971            updated_at: OffsetDateTime::now_utc() - Duration::from_secs(3600),
972        };
973        store.save(&key, &initial_tokens).await.unwrap();
974
975        let provider = RefreshTokenProvider::new(flow, Arc::clone(&store), key.clone());
976
977        // Get access token (should trigger refresh because token is expired)
978        let access_token = provider.access_token().await.unwrap();
979        assert_eq!(access_token, "new-access-token");
980
981        // Verify original refresh token is preserved
982        let stored = store.load(&key).await.unwrap().unwrap();
983        assert_eq!(stored.refresh_token, "original-refresh-token");
984
985        // Cleanup
986        store.delete(&key).await.unwrap();
987    }
988
989    #[tokio::test]
990    #[serial_test::serial]
991    async fn test_refresh_token_updated_when_included_in_response() {
992        let server = MockServer::start().await;
993
994        // Mock refresh endpoint that returns new refresh_token
995        Mock::given(method("POST"))
996            .and(path("/token"))
997            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
998                "access_token": "new-access-token",
999                "refresh_token": "new-refresh-token",
1000                "expires_in": 3600,
1001                "token_type": "Bearer"
1002            })))
1003            .mount(&server)
1004            .await;
1005
1006        let config = OAuthConfig {
1007            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1008            token_endpoint: format!("{}/token", server.uri()),
1009            client_id: "test-client-id".to_string(),
1010            client_secret: None,
1011            redirect_uri: "http://127.0.0.1:4317".to_string(),
1012            scopes: vec!["scope1".to_string()],
1013            audience: None,
1014            additional_params: HashMap::new(),
1015        };
1016
1017        let http = Arc::new(Client::new());
1018        let flow = OAuthFlow::new(config, http).unwrap();
1019        let temp_dir = tempdir().unwrap();
1020        let store_path = temp_dir.path().join("credentials.json");
1021        let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
1022        let key = TokenStoreKey {
1023            profile: ApiProfile::Enterprise,
1024            project_number: Some("update-test".to_string()),
1025            endpoint_location: Some("global".to_string()),
1026            user_hint: None,
1027        };
1028
1029        // Store initial refresh token with expired access token
1030        let initial_tokens = SerializedTokens {
1031            refresh_token: "original-refresh-token".to_string(),
1032            scopes: vec!["scope1".to_string()],
1033            expires_at: Some(OffsetDateTime::now_utc() - Duration::from_secs(3600)), // Expired
1034            token_type: "Bearer".to_string(),
1035            updated_at: OffsetDateTime::now_utc() - Duration::from_secs(3600),
1036        };
1037        store.save(&key, &initial_tokens).await.unwrap();
1038
1039        let provider = RefreshTokenProvider::new(flow, Arc::clone(&store), key.clone());
1040
1041        // Get access token (should trigger refresh because token is expired)
1042        let access_token = provider.access_token().await.unwrap();
1043        assert_eq!(access_token, "new-access-token");
1044
1045        // Verify refresh token is updated
1046        let stored = store.load(&key).await.unwrap().unwrap();
1047        assert_eq!(stored.refresh_token, "new-refresh-token");
1048
1049        // Cleanup
1050        store.delete(&key).await.unwrap();
1051    }
1052
1053    // Test 6: State validation failure tests (CSRF)
1054    #[tokio::test]
1055    async fn test_state_mismatch_detection() {
1056        let config = OAuthConfig {
1057            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1058            token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1059            client_id: "test-client-id".to_string(),
1060            client_secret: None,
1061            redirect_uri: "http://127.0.0.1:4317".to_string(),
1062            scopes: vec!["scope1".to_string()],
1063            audience: None,
1064            additional_params: HashMap::new(),
1065        };
1066
1067        let http = Arc::new(Client::new());
1068        let flow = OAuthFlow::new(config, http).unwrap();
1069
1070        let context = flow.build_authorize_url(&AuthorizeParams {
1071            state: None,
1072            code_challenge: None,
1073            code_challenge_method: None,
1074        });
1075
1076        // Simulate state mismatch (CSRF attack)
1077        let wrong_state = "attacker-controlled-state";
1078        assert_ne!(context.state, wrong_state);
1079
1080        // In actual implementation, this would be caught by the callback handler
1081        // which compares the received state with context.state
1082    }
1083
1084    #[test]
1085    fn test_authorize_url_contains_required_parameters() {
1086        let config = OAuthConfig {
1087            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1088            token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1089            client_id: "test-client-id".to_string(),
1090            client_secret: None,
1091            redirect_uri: "http://127.0.0.1:4317".to_string(),
1092            scopes: vec!["scope1".to_string(), "scope2".to_string()],
1093            audience: None,
1094            additional_params: HashMap::new(),
1095        };
1096
1097        let http = Arc::new(Client::new());
1098        let flow = OAuthFlow::new(config, http).unwrap();
1099
1100        let context = flow.build_authorize_url(&AuthorizeParams {
1101            state: None,
1102            code_challenge: None,
1103            code_challenge_method: None,
1104        });
1105
1106        let url = url::Url::parse(&context.url).unwrap();
1107        let params: HashMap<_, _> = url.query_pairs().collect();
1108
1109        // Verify required PKCE and OAuth parameters
1110        assert!(params.contains_key("client_id"));
1111        assert!(params.contains_key("redirect_uri"));
1112        assert!(params.contains_key("response_type"));
1113        assert_eq!(params.get("response_type").unwrap(), "code");
1114        assert!(params.contains_key("scope"));
1115        assert!(params.contains_key("state"));
1116        assert!(params.contains_key("code_challenge"));
1117        assert!(params.contains_key("code_challenge_method"));
1118        assert_eq!(params.get("code_challenge_method").unwrap(), "S256");
1119        assert!(params.contains_key("access_type"));
1120        assert_eq!(params.get("access_type").unwrap(), "offline");
1121        assert!(params.contains_key("prompt"));
1122        assert_eq!(params.get("prompt").unwrap(), "consent");
1123    }
1124
1125    #[test]
1126    fn test_custom_state_is_preserved() {
1127        let config = OAuthConfig {
1128            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1129            token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1130            client_id: "test-client-id".to_string(),
1131            client_secret: None,
1132            redirect_uri: "http://127.0.0.1:4317".to_string(),
1133            scopes: vec!["scope1".to_string()],
1134            audience: None,
1135            additional_params: HashMap::new(),
1136        };
1137
1138        let http = Arc::new(Client::new());
1139        let flow = OAuthFlow::new(config, http).unwrap();
1140
1141        let custom_state = "my-custom-state-value";
1142        let context = flow.build_authorize_url(&AuthorizeParams {
1143            state: Some(custom_state.to_string()),
1144            code_challenge: None,
1145            code_challenge_method: None,
1146        });
1147
1148        assert_eq!(context.state, custom_state);
1149        assert!(context.url.contains(&format!("state={}", custom_state)));
1150    }
1151}