// nblm_core/auth/oauth/mod.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use oauth2::{
9    basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointSet,
10    PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse,
11    TokenResponse as OAuth2TokenResponse, TokenUrl,
12};
13use parking_lot::RwLock;
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use time::OffsetDateTime;
17
18use crate::auth::{ProviderKind, TokenProvider};
19use crate::env::ApiProfile;
20use crate::error::{Error as CoreError, Result as CoreResult};
21
22mod config;
23mod error;
24pub mod loopback;
25#[cfg(test)]
26pub mod testing;
27
28pub use config::OAuthClientConfig;
29pub use error::{OAuthError, Result};
30
31impl From<OAuthError> for CoreError {
32    fn from(err: OAuthError) -> Self {
33        CoreError::TokenProvider(err.to_string())
34    }
35}
36
// ============================================================================
// OAuth Configuration
// ============================================================================

/// OAuth2 configuration for the Authorization Code Flow with PKCE.
///
/// The endpoint and redirect URLs are kept as plain strings; they are parsed
/// (and rejected if invalid) by [`OAuthFlow::new`].
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Authorization endpoint the user is sent to.
    pub auth_endpoint: String,
    /// Token endpoint used for code exchange and refresh.
    pub token_endpoint: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret; not configured on the client when `None`.
    pub client_secret: Option<String>,
    /// Redirect URI sent with the authorization request.
    pub redirect_uri: String,
    /// Scopes requested during authorization.
    pub scopes: Vec<String>,
    /// Optional `audience` query parameter appended to the authorization URL.
    pub audience: Option<String>,
    /// Extra query parameters appended to the authorization URL.
    pub additional_params: HashMap<String, String>,
}

/// Google's OAuth 2.0 token revocation endpoint, used unless overridden.
const GOOGLE_REVOKE_ENDPOINT: &str = "https://oauth2.googleapis.com/revoke";
/// Environment variable that, when set, replaces [`GOOGLE_REVOKE_ENDPOINT`].
const REVOKE_ENDPOINT_ENV: &str = "NBLM_OAUTH_REVOKE_ENDPOINT";

57impl OAuthConfig {
58    pub const DEFAULT_REDIRECT_URI: &str = "http://127.0.0.1:4317";
59    pub(crate) const AUTH_ENDPOINT: &str = "https://accounts.google.com/o/oauth2/v2/auth";
60    pub(crate) const TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
61    pub(crate) const SCOPE_CLOUD_PLATFORM: &str = "https://www.googleapis.com/auth/cloud-platform";
62    pub(crate) const SCOPE_DRIVE_FILE: &str = "https://www.googleapis.com/auth/drive.file";
63
64    /// Create a default Google OAuth2 configuration for NotebookLM Enterprise
65    pub fn google_default(_project_number: &str) -> Result<Self> {
66        OAuthClientConfig::from_env().map(|cfg| cfg.into_oauth_config())
67    }
68}
69
70// ============================================================================
71// OAuth Tokens
72// ============================================================================
73
74/// OAuth2 tokens returned from token endpoint
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct OAuthTokens {
77    pub access_token: String,
78    pub refresh_token: Option<String>,
79    pub expires_at: OffsetDateTime,
80    pub scope: Option<String>,
81    pub token_type: String,
82}
83
84impl OAuthTokens {
85    /// Create OAuthTokens from oauth2-rs token response
86    pub fn from_oauth2_response(
87        response: StandardTokenResponse<
88            oauth2::EmptyExtraTokenFields,
89            oauth2::basic::BasicTokenType,
90        >,
91        issued_at: OffsetDateTime,
92    ) -> Self {
93        let expires_at = issued_at
94            + response
95                .expires_in()
96                .map(|d| Duration::from_secs(d.as_secs()))
97                .unwrap_or_else(|| Duration::from_secs(3600));
98
99        let scope = response.scopes().map(|scopes| {
100            scopes
101                .iter()
102                .map(|s| s.as_str().to_string())
103                .collect::<Vec<_>>()
104                .join(" ")
105        });
106
107        Self {
108            access_token: response.access_token().secret().to_string(),
109            refresh_token: response.refresh_token().map(|rt| rt.secret().to_string()),
110            expires_at,
111            scope,
112            token_type: match response.token_type() {
113                oauth2::basic::BasicTokenType::Bearer => "Bearer".to_string(),
114                oauth2::basic::BasicTokenType::Mac => "MAC".to_string(),
115                oauth2::basic::BasicTokenType::Extension(s) => s.clone(),
116            },
117        }
118    }
119}
120
121// ============================================================================
122// Token Cache Entry
123// ============================================================================
124
125/// In-memory cache entry for OAuth tokens
126#[derive(Debug, Clone)]
127pub struct TokenCacheEntry {
128    pub tokens: OAuthTokens,
129    pub refresh_margin: Duration,
130}
131
132impl TokenCacheEntry {
133    pub fn new(tokens: OAuthTokens) -> Self {
134        Self {
135            tokens,
136            refresh_margin: Duration::from_secs(60), // Default 60 seconds
137        }
138    }
139
140    /// Check if token needs refresh
141    pub fn needs_refresh(&self, now: OffsetDateTime) -> bool {
142        now >= (self.tokens.expires_at - self.refresh_margin)
143    }
144}
145
146// ============================================================================
147// Token Store Key
148// ============================================================================
149
150/// Key for storing tokens in RefreshTokenStore
151#[derive(Debug, Clone, PartialEq, Eq, Hash)]
152pub struct TokenStoreKey {
153    pub profile: ApiProfile,
154    pub project_number: Option<String>,
155    pub endpoint_location: Option<String>,
156    pub user_hint: Option<String>,
157}
158
159impl fmt::Display for TokenStoreKey {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        let mut parts = vec![self.profile.as_str().to_string()];
162
163        if let Some(ref project) = self.project_number {
164            parts.push(format!("project={}", project));
165        }
166
167        if let Some(ref location) = self.endpoint_location {
168            parts.push(format!("location={}", location));
169        }
170
171        if let Some(ref user) = self.user_hint {
172            parts.push(format!("user={}", user));
173        }
174
175        write!(f, "{}", parts.join(":"))
176    }
177}
178
179// ============================================================================
180// Serialized Tokens (for storage)
181// ============================================================================
182
183/// Serialized token data for storage
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct SerializedTokens {
186    pub refresh_token: String,
187    pub scopes: Vec<String>,
188    pub expires_at: Option<OffsetDateTime>,
189    pub token_type: String,
190    #[serde(with = "time::serde::rfc3339")]
191    pub updated_at: OffsetDateTime,
192}
193
194/// Credentials file format
195#[derive(Debug, Serialize, Deserialize)]
196struct CredentialsFile {
197    version: u32,
198    entries: HashMap<String, SerializedTokens>,
199}
200
201impl CredentialsFile {
202    fn new() -> Self {
203        Self {
204            version: 1,
205            entries: HashMap::new(),
206        }
207    }
208}
209
210// ============================================================================
211// RefreshTokenStore Trait
212// ============================================================================
213
214/// Trait for storing and retrieving refresh tokens
215#[async_trait]
216pub trait RefreshTokenStore: Send + Sync {
217    /// Load tokens for the given key
218    async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>>;
219
220    /// Save tokens for the given key
221    async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()>;
222
223    /// Delete tokens for the given key
224    async fn delete(&self, key: &TokenStoreKey) -> Result<()>;
225}
226
// ============================================================================
// FileRefreshTokenStore
// ============================================================================

/// Environment variable naming a directory that holds `credentials.json`
/// instead of the platform config directory.
const CONFIG_DIR_ENV: &str = "NBLM_CONFIG_DIR";

/// [`RefreshTokenStore`] backed by a single JSON credentials file on disk.
pub struct FileRefreshTokenStore {
    /// Location of the credentials JSON file.
    file_path: PathBuf,
}

238impl FileRefreshTokenStore {
239    /// Create a new FileRefreshTokenStore
240    pub fn new() -> Result<Self> {
241        if let Ok(custom_dir) = std::env::var(CONFIG_DIR_ENV) {
242            let file_path = PathBuf::from(custom_dir).join("credentials.json");
243            return Self::from_path(file_path);
244        }
245
246        let dirs = directories::ProjectDirs::from("com", "nblm", "nblm-rs")
247            .ok_or_else(|| OAuthError::Config("failed to find config directory".to_string()))?;
248
249        let config_dir = dirs.config_dir();
250        let file_path = config_dir.join("credentials.json");
251
252        Self::from_path(file_path)
253    }
254
255    /// Create a store backed by an explicit credentials file path.
256    pub fn from_path(path: impl Into<PathBuf>) -> Result<Self> {
257        Ok(Self {
258            file_path: path.into(),
259        })
260    }
261
262    /// Ensure config directory exists with proper permissions (async)
263    async fn ensure_config_dir(&self) -> Result<()> {
264        if let Some(config_dir) = self.file_path.parent() {
265            tokio::fs::create_dir_all(config_dir).await.map_err(|e| {
266                OAuthError::Config(format!("failed to create config directory: {}", e))
267            })?;
268
269            #[cfg(unix)]
270            {
271                use std::os::unix::fs::PermissionsExt;
272                let mut perms = tokio::fs::metadata(config_dir)
273                    .await
274                    .map_err(|e| {
275                        OAuthError::Config(format!("failed to get config dir metadata: {}", e))
276                    })?
277                    .permissions();
278                perms.set_mode(0o700);
279                tokio::fs::set_permissions(config_dir, perms)
280                    .await
281                    .map_err(|e| {
282                        OAuthError::Config(format!("failed to set config dir permissions: {}", e))
283                    })?;
284            }
285        }
286        Ok(())
287    }
288
289    /// Load credentials file
290    async fn load_file(&self) -> Result<CredentialsFile> {
291        self.ensure_config_dir().await?;
292
293        if !self.file_path.exists() {
294            return Ok(CredentialsFile::new());
295        }
296
297        let content = tokio::fs::read_to_string(&self.file_path)
298            .await
299            .map_err(|e| OAuthError::Config(format!("failed to read credentials file: {}", e)))?;
300
301        let file: CredentialsFile = serde_json::from_str(&content)
302            .map_err(|e| OAuthError::Config(format!("failed to parse credentials file: {}", e)))?;
303
304        Ok(file)
305    }
306
307    /// Save credentials file
308    async fn save_file(&self, file: &CredentialsFile) -> Result<()> {
309        self.ensure_config_dir().await?;
310
311        let content = serde_json::to_string_pretty(file)
312            .map_err(|e| OAuthError::Config(format!("failed to serialize credentials: {}", e)))?;
313
314        // Write to temp file first, then rename (atomic write)
315        // Use a unique temporary file name to avoid conflicts in concurrent writes
316        let random_suffix = {
317            use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
318            use rand::RngCore;
319            let mut rng = rand::rng();
320            let mut random_bytes = [0u8; 8];
321            rng.fill_bytes(&mut random_bytes);
322            URL_SAFE_NO_PAD.encode(random_bytes)
323        };
324
325        let temp_path = self.file_path.with_file_name(format!(
326            "{}.{}.tmp",
327            self.file_path
328                .file_name()
329                .and_then(|n| n.to_str())
330                .unwrap_or("credentials.json"),
331            random_suffix
332        ));
333        tokio::fs::write(&temp_path, content)
334            .await
335            .map_err(|e| OAuthError::Config(format!("failed to write credentials file: {}", e)))?;
336
337        #[cfg(unix)]
338        {
339            use std::os::unix::fs::PermissionsExt;
340            if let Ok(metadata) = tokio::fs::metadata(&temp_path).await {
341                let mut perms = metadata.permissions();
342                perms.set_mode(0o600);
343                let _ = tokio::fs::set_permissions(&temp_path, perms).await;
344            }
345        }
346
347        tokio::fs::rename(&temp_path, &self.file_path)
348            .await
349            .map_err(|e| OAuthError::Config(format!("failed to rename temp file: {}", e)))?;
350
351        Ok(())
352    }
353
354    /// Remove the entire credentials file from disk (best-effort).
355    pub async fn delete_file(&self) -> Result<()> {
356        match tokio::fs::remove_file(&self.file_path).await {
357            Ok(_) => Ok(()),
358            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
359            Err(err) => Err(OAuthError::Storage(err)),
360        }
361    }
362}
363
364#[async_trait]
365impl RefreshTokenStore for FileRefreshTokenStore {
366    async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>> {
367        let file = self.load_file().await?;
368        Ok(file.entries.get(&key.to_string()).cloned())
369    }
370
371    async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()> {
372        let mut file = self.load_file().await?;
373        file.entries.insert(key.to_string(), tokens.clone());
374        self.save_file(&file).await
375    }
376
377    async fn delete(&self, key: &TokenStoreKey) -> Result<()> {
378        let mut file = self.load_file().await?;
379        file.entries.remove(&key.to_string());
380        self.save_file(&file).await
381    }
382}
383
// ============================================================================
// OAuth Flow
// ============================================================================

/// Caller-supplied options for `OAuthFlow::build_authorize_url`.
#[derive(Debug, Clone)]
pub struct AuthorizeParams {
    /// Fixed `state` value; a random one is generated when `None`.
    pub state: Option<String>,
    /// Currently unused: `build_authorize_url` always generates its own PKCE pair.
    pub code_challenge: Option<String>,
    /// Currently unused, like `code_challenge`.
    pub code_challenge_method: Option<String>,
}

396/// Context for authorization flow
397#[derive(Debug, Clone)]
398pub struct AuthorizeContext {
399    pub url: String,
400    pub state: String,
401    pub code_verifier: String,
402    pub expires_at: OffsetDateTime,
403}
404
405/// OAuth2 Authorization Code Flow with PKCE
406pub struct OAuthFlow {
407    client: BasicClient<
408        EndpointSet,
409        oauth2::EndpointNotSet,
410        oauth2::EndpointNotSet,
411        oauth2::EndpointNotSet,
412        EndpointSet,
413    >,
414    config: OAuthConfig,
415    http: Arc<Client>,
416}
417
418impl OAuthFlow {
419    /// Create a new OAuthFlow
420    pub fn new(config: OAuthConfig, http: Arc<Client>) -> Result<Self> {
421        let client_id = ClientId::new(config.client_id.clone());
422        let auth_url = AuthUrl::new(config.auth_endpoint.clone())
423            .map_err(|e| OAuthError::Config(format!("invalid auth_url: {}", e)))?;
424        let token_url = TokenUrl::new(config.token_endpoint.clone())
425            .map_err(|e| OAuthError::Config(format!("invalid token_url: {}", e)))?;
426        let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
427            .map_err(|e| OAuthError::Config(format!("invalid redirect_url: {}", e)))?;
428
429        let mut client_builder = BasicClient::new(client_id)
430            .set_auth_uri(auth_url)
431            .set_token_uri(token_url)
432            .set_redirect_uri(redirect_url);
433
434        if let Some(ref client_secret) = config.client_secret {
435            client_builder =
436                client_builder.set_client_secret(ClientSecret::new(client_secret.clone()));
437        }
438
439        Ok(Self {
440            client: client_builder,
441            config,
442            http,
443        })
444    }
445
446    /// Build authorization URL with PKCE
447    pub fn build_authorize_url(&self, params: &AuthorizeParams) -> AuthorizeContext {
448        // Generate PKCE challenge and verifier
449        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
450
451        // Generate CSRF token
452        let csrf_token = if let Some(ref state) = params.state {
453            CsrfToken::new(state.clone())
454        } else {
455            CsrfToken::new_random()
456        };
457
458        // Build authorization URL
459        let mut auth_request = self.client.authorize_url(|| csrf_token.clone());
460
461        // Add scopes
462        for scope_str in &self.config.scopes {
463            auth_request = auth_request.add_scope(Scope::new(scope_str.clone()));
464        }
465
466        // Set PKCE challenge
467        auth_request = auth_request.set_pkce_challenge(pkce_challenge);
468
469        // Add additional parameters
470        for (key, value) in &self.config.additional_params {
471            auth_request = auth_request.add_extra_param(key, value);
472        }
473
474        // Build the URL
475        let (auth_url, csrf_token_actual) = auth_request.url();
476
477        // Add Google-specific parameters
478        let mut url = url::Url::parse(auth_url.as_str()).expect("invalid auth_url");
479        url.query_pairs_mut()
480            .append_pair("access_type", "offline")
481            .append_pair("prompt", "consent");
482
483        if let Some(ref audience) = self.config.audience {
484            url.query_pairs_mut().append_pair("audience", audience);
485        }
486
487        let expires_at = OffsetDateTime::now_utc() + Duration::from_secs(600); // 10 minutes
488
489        AuthorizeContext {
490            url: url.to_string(),
491            state: csrf_token_actual.secret().to_string(),
492            code_verifier: pkce_verifier.secret().to_string(),
493            expires_at,
494        }
495    }
496
497    /// Exchange authorization code for tokens
498    pub async fn exchange_code(
499        &self,
500        context: &AuthorizeContext,
501        code: &str,
502    ) -> Result<OAuthTokens> {
503        let code = AuthorizationCode::new(code.to_string());
504        let pkce_verifier = PkceCodeVerifier::new(context.code_verifier.clone());
505
506        let token_request = self
507            .client
508            .exchange_code(code)
509            .set_pkce_verifier(pkce_verifier);
510
511        let token_response = token_request
512            .request_async(self.http.as_ref())
513            .await
514            .map_err(|e| OAuthError::Config(format!("oauth token exchange failed: {}", e)))?;
515
516        Ok(OAuthTokens::from_oauth2_response(
517            token_response,
518            OffsetDateTime::now_utc(),
519        ))
520    }
521
522    /// Refresh access token using refresh token
523    pub async fn refresh(&self, refresh_token: &str) -> Result<OAuthTokens> {
524        let refresh_token = RefreshToken::new(refresh_token.to_string());
525
526        let token_request = self.client.exchange_refresh_token(&refresh_token);
527
528        let token_response = token_request
529            .request_async(self.http.as_ref())
530            .await
531            .map_err(|e| OAuthError::Config(format!("oauth token refresh failed: {}", e)))?;
532
533        Ok(OAuthTokens::from_oauth2_response(
534            token_response,
535            OffsetDateTime::now_utc(),
536        ))
537    }
538
539    /// Revoke a refresh token via Google's revocation endpoint.
540    pub async fn revoke_refresh_token(&self, refresh_token: &str) -> Result<()> {
541        if refresh_token.trim().is_empty() {
542            return Err(OAuthError::Revocation(
543                "refresh token cannot be empty".into(),
544            ));
545        }
546
547        let endpoint = std::env::var(REVOKE_ENDPOINT_ENV)
548            .unwrap_or_else(|_| GOOGLE_REVOKE_ENDPOINT.to_string());
549
550        let response = self
551            .http
552            .post(&endpoint)
553            .form(&[("token", refresh_token)])
554            .send()
555            .await
556            .map_err(|e| OAuthError::Revocation(format!("revocation request failed: {}", e)))?;
557
558        if response.status().is_success() {
559            Ok(())
560        } else {
561            let status = response.status();
562            let body = response.text().await.unwrap_or_default();
563            Err(OAuthError::Revocation(format!(
564                "revocation failed (status {}): {}",
565                status, body
566            )))
567        }
568    }
569}
570
571// ============================================================================
572// RefreshTokenProvider
573// ============================================================================
574
575/// TokenProvider implementation using refresh tokens
576pub struct RefreshTokenProvider<S: RefreshTokenStore> {
577    flow: OAuthFlow,
578    store: Arc<S>,
579    cache: RwLock<Option<TokenCacheEntry>>,
580    store_key: TokenStoreKey,
581}
582
583impl<S: RefreshTokenStore> RefreshTokenProvider<S> {
584    /// Create a new RefreshTokenProvider
585    pub fn new(flow: OAuthFlow, store: Arc<S>, store_key: TokenStoreKey) -> Self {
586        Self {
587            flow,
588            store,
589            cache: RwLock::new(None),
590            store_key,
591        }
592    }
593
594    /// Ensure tokens are valid, refreshing if necessary
595    async fn ensure_tokens(&self, force_refresh: bool) -> Result<OAuthTokens> {
596        let now = OffsetDateTime::now_utc();
597
598        // Check cache first
599        if !force_refresh {
600            if let Some(ref entry) = *self.cache.read() {
601                if !entry.needs_refresh(now) {
602                    return Ok(entry.tokens.clone());
603                }
604            }
605        }
606
607        // Load refresh token from store
608        let stored = self
609            .store
610            .load(&self.store_key)
611            .await?
612            .ok_or_else(|| OAuthError::Config("refresh token unavailable".to_string()))?;
613
614        // Refresh access token
615        let tokens = self.flow.refresh(&stored.refresh_token).await?;
616        let refresh_token = tokens
617            .refresh_token
618            .clone()
619            .unwrap_or_else(|| stored.refresh_token.clone());
620        let scopes: Vec<String> = if let Some(scopes_str) = tokens.scope.as_ref() {
621            scopes_str.split_whitespace().map(String::from).collect()
622        } else if !stored.scopes.is_empty() {
623            stored.scopes.clone()
624        } else {
625            Vec::new()
626        };
627        let token_type = if !tokens.token_type.is_empty() {
628            tokens.token_type.clone()
629        } else {
630            stored.token_type.clone()
631        };
632
633        // Update cache
634        {
635            let mut cache = self.cache.write();
636            *cache = Some(TokenCacheEntry::new(tokens.clone()));
637        }
638
639        // Persist refresh token details (preserve previous token when refresh response omits it)
640        let serialized = SerializedTokens {
641            refresh_token,
642            scopes,
643            expires_at: Some(tokens.expires_at),
644            token_type,
645            updated_at: now,
646        };
647        self.store.save(&self.store_key, &serialized).await?;
648
649        Ok(tokens)
650    }
651}
652
653#[async_trait]
654impl<S: RefreshTokenStore> TokenProvider for RefreshTokenProvider<S> {
655    async fn access_token(&self) -> CoreResult<String> {
656        let tokens = self.ensure_tokens(false).await.map_err(CoreError::from)?;
657        Ok(tokens.access_token)
658    }
659
660    async fn refresh_token(&self) -> CoreResult<String> {
661        let tokens = self.ensure_tokens(true).await.map_err(CoreError::from)?;
662        Ok(tokens.access_token)
663    }
664
665    fn kind(&self) -> ProviderKind {
666        ProviderKind::UserOauth
667    }
668}
669
670#[cfg(test)]
671mod tests {
672    use super::*;
673    use crate::auth::oauth::testing::fake::FakeOAuthServer;
674    use tempfile::tempdir;
675    use wiremock::matchers::{method, path};
676    use wiremock::{Mock, MockServer, ResponseTemplate};
677
678    #[tokio::test]
679    async fn test_token_cache_entry_needs_refresh() {
680        let now = OffsetDateTime::now_utc();
681        let expires_at = now + Duration::from_secs(120); // 2 minutes
682        let tokens = OAuthTokens {
683            access_token: "test-token".to_string(),
684            refresh_token: None,
685            expires_at,
686            scope: None,
687            token_type: "Bearer".to_string(),
688        };
689
690        let entry = TokenCacheEntry::new(tokens);
691
692        // Should not need refresh immediately
693        assert!(!entry.needs_refresh(now));
694
695        // Should need refresh when close to expiry (within margin)
696        let near_expiry = expires_at - Duration::from_secs(30);
697        assert!(entry.needs_refresh(near_expiry));
698    }
699
700    #[tokio::test]
701    async fn test_token_store_key_display() {
702        let key = TokenStoreKey {
703            profile: ApiProfile::Enterprise,
704            project_number: Some("123456".to_string()),
705            endpoint_location: Some("global".to_string()),
706            user_hint: None,
707        };
708
709        let display = key.to_string();
710        assert!(display.contains("enterprise"));
711        assert!(display.contains("project=123456"));
712        assert!(display.contains("location=global"));
713    }
714
715    #[tokio::test]
716    async fn test_oauth_tokens_from_oauth2_response() {
717        use oauth2::StandardTokenResponse;
718
719        let now = OffsetDateTime::now_utc();
720        // Manually create a response with all fields using serde_json
721        let json_response = serde_json::json!({
722            "access_token": "access-token-123",
723            "refresh_token": "refresh-token-456",
724            "expires_in": 3600,
725            "token_type": "Bearer",
726            "scope": "scope1 scope2"
727        });
728
729        let response: StandardTokenResponse<
730            oauth2::EmptyExtraTokenFields,
731            oauth2::basic::BasicTokenType,
732        > = serde_json::from_value(json_response).unwrap();
733
734        let tokens = OAuthTokens::from_oauth2_response(response, now);
735        assert_eq!(tokens.access_token, "access-token-123");
736        assert_eq!(tokens.refresh_token, Some("refresh-token-456".to_string()));
737        assert_eq!(tokens.scope, Some("scope1 scope2".to_string()));
738        assert_eq!(tokens.token_type, "Bearer");
739        assert!(tokens.expires_at > now);
740    }
741
742    #[tokio::test]
743    async fn test_oauth_flow_refresh_token() {
744        let server = MockServer::start().await;
745
746        Mock::given(method("POST"))
747            .and(path("/token"))
748            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
749                "access_token": "new-access-token",
750                "expires_in": 3600,
751                "token_type": "Bearer"
752            })))
753            .mount(&server)
754            .await;
755
756        let config = OAuthConfig {
757            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
758            token_endpoint: format!("{}/token", server.uri()),
759            client_id: "test-client-id".to_string(),
760            client_secret: None,
761            redirect_uri: "http://127.0.0.1:4317".to_string(),
762            scopes: vec!["scope1".to_string()],
763            audience: None,
764            additional_params: HashMap::new(),
765        };
766
767        let http = Arc::new(Client::new());
768        let flow = OAuthFlow::new(config, http).unwrap();
769
770        let tokens = flow.refresh("refresh-token-123").await.unwrap();
771        assert_eq!(tokens.access_token, "new-access-token");
772    }
773
774    // Test 1: PKCE validity tests (security critical)
775    #[test]
776    fn test_pkce_code_verifier_and_challenge_generation() {
777        let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
778
779        // Verify code_verifier is base64url encoded and proper length
780        let verifier_str = verifier.secret();
781        assert!(!verifier_str.is_empty());
782        assert!(verifier_str.len() >= 43 && verifier_str.len() <= 128);
783        assert!(verifier_str
784            .chars()
785            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
786
787        // Verify code_challenge is base64url encoded SHA256 hash
788        let challenge_str = challenge.as_str();
789        assert!(!challenge_str.is_empty());
790        assert_eq!(challenge_str.len(), 43); // SHA256 hash is 32 bytes, base64url is 43 chars
791        assert!(challenge_str
792            .chars()
793            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
794    }
795
796    #[test]
797    fn test_pkce_generates_unique_values() {
798        let (challenge1, verifier1) = PkceCodeChallenge::new_random_sha256();
799        let (challenge2, verifier2) = PkceCodeChallenge::new_random_sha256();
800
801        // Each generation should produce unique values
802        assert_ne!(verifier1.secret(), verifier2.secret());
803        assert_ne!(challenge1.as_str(), challenge2.as_str());
804    }
805
806    #[test]
807    fn test_state_generation_uniqueness() {
808        let state1 = CsrfToken::new_random();
809        let state2 = CsrfToken::new_random();
810
811        // States should be unique
812        assert_ne!(state1.secret(), state2.secret());
813        assert!(!state1.secret().is_empty());
814        assert!(!state2.secret().is_empty());
815    }
816
817    // Test 2: Token expiration boundary tests
818    #[test]
819    fn test_token_needs_refresh_at_exact_expiry() {
820        let now = OffsetDateTime::now_utc();
821        let expires_at = now + Duration::from_secs(60);
822        let tokens = OAuthTokens {
823            access_token: "test-token".to_string(),
824            refresh_token: None,
825            expires_at,
826            scope: None,
827            token_type: "Bearer".to_string(),
828        };
829
830        let entry = TokenCacheEntry::new(tokens);
831
832        // At exact expiry time (accounting for margin)
833        assert!(entry.needs_refresh(expires_at));
834    }
835
836    #[test]
837    fn test_token_needs_refresh_just_before_margin() {
838        let now = OffsetDateTime::now_utc();
839        let expires_at = now + Duration::from_secs(120);
840        let tokens = OAuthTokens {
841            access_token: "test-token".to_string(),
842            refresh_token: None,
843            expires_at,
844            scope: None,
845            token_type: "Bearer".to_string(),
846        };
847
848        let entry = TokenCacheEntry::new(tokens);
849
850        // Just before margin (61 seconds before expiry, margin is 60)
851        let before_margin = expires_at - Duration::from_secs(61);
852        assert!(!entry.needs_refresh(before_margin));
853
854        // Just at margin (60 seconds before expiry)
855        let at_margin = expires_at - Duration::from_secs(60);
856        assert!(entry.needs_refresh(at_margin));
857    }
858
859    #[test]
860    fn test_token_needs_refresh_after_expiry() {
861        let now = OffsetDateTime::now_utc();
862        let expires_at = now - Duration::from_secs(10); // Already expired
863        let tokens = OAuthTokens {
864            access_token: "test-token".to_string(),
865            refresh_token: None,
866            expires_at,
867            scope: None,
868            token_type: "Bearer".to_string(),
869        };
870
871        let entry = TokenCacheEntry::new(tokens);
872        assert!(entry.needs_refresh(now));
873    }
874
875    // Test 3: File store concurrent access tests
876    #[tokio::test]
877    #[serial_test::serial]
878    async fn test_file_store_concurrent_saves() {
879        use std::sync::Arc;
880        use tokio::task::JoinSet;
881
882        let temp_dir = tempdir().unwrap();
883        let store_path = temp_dir.path().join("credentials.json");
884        let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
885        let key = TokenStoreKey {
886            profile: ApiProfile::Enterprise,
887            project_number: Some("concurrent-test".to_string()),
888            endpoint_location: Some("global".to_string()),
889            user_hint: None,
890        };
891
892        let mut join_set = JoinSet::new();
893
894        // Spawn 10 concurrent save operations
895        for i in 0..10 {
896            let store_clone = Arc::clone(&store);
897            let key_clone = key.clone();
898            join_set.spawn(async move {
899                let tokens = SerializedTokens {
900                    refresh_token: format!("token-{}", i),
901                    scopes: vec!["scope1".to_string()],
902                    expires_at: Some(OffsetDateTime::now_utc() + Duration::from_secs(3600)),
903                    token_type: "Bearer".to_string(),
904                    updated_at: OffsetDateTime::now_utc(),
905                };
906                store_clone.save(&key_clone, &tokens).await
907            });
908        }
909
910        // Wait for all operations to complete
911        while let Some(result) = join_set.join_next().await {
912            result.unwrap().unwrap();
913        }
914
915        // Verify that the final state is consistent
916        let loaded = store.load(&key).await.unwrap();
917        assert!(loaded.is_some());
918        let loaded = loaded.unwrap();
919        assert!(loaded.refresh_token.starts_with("token-"));
920
921        // Cleanup
922        store.delete(&key).await.unwrap();
923    }
924
925    #[tokio::test]
926    #[serial_test::serial]
927    async fn test_file_store_atomic_write() {
928        let temp_dir = tempdir().unwrap();
929        let store_path = temp_dir.path().join("credentials.json");
930        let store = FileRefreshTokenStore::from_path(&store_path).unwrap();
931        let key = TokenStoreKey {
932            profile: ApiProfile::Enterprise,
933            project_number: Some("atomic-test".to_string()),
934            endpoint_location: Some("global".to_string()),
935            user_hint: None,
936        };
937
938        let tokens = SerializedTokens {
939            refresh_token: "initial-token".to_string(),
940            scopes: vec!["scope1".to_string()],
941            expires_at: Some(OffsetDateTime::now_utc() + Duration::from_secs(3600)),
942            token_type: "Bearer".to_string(),
943            updated_at: OffsetDateTime::now_utc(),
944        };
945
946        store.save(&key, &tokens).await.unwrap();
947
948        // Verify temp files are cleaned up after save
949        let entries: Vec<_> = std::fs::read_dir(temp_dir.path())
950            .unwrap()
951            .map(|entry| entry.unwrap().file_name())
952            .collect();
953        assert_eq!(entries.len(), 1);
954        assert_eq!(entries[0], std::ffi::OsStr::new("credentials.json"));
955
956        // Cleanup
957        store.delete(&key).await.unwrap();
958    }
959
960    #[test]
961    #[serial_test::serial]
962    fn test_file_store_respects_custom_config_dir() {
963        use std::env;
964
965        let temp_dir = tempdir().unwrap();
966        let key = CONFIG_DIR_ENV;
967        let original = env::var(key).ok();
968        env::set_var(key, temp_dir.path());
969
970        let store = FileRefreshTokenStore::new().expect("store should honor custom dir");
971        assert!(store.file_path.starts_with(temp_dir.path()));
972
973        if let Some(value) = original {
974            env::set_var(key, value);
975        } else {
976            env::remove_var(key);
977        }
978    }
979
980    // Test 5: refresh_token omission handling tests
981    #[tokio::test]
982    #[serial_test::serial]
983    async fn test_refresh_token_preserved_when_omitted_in_response() {
984        let server = MockServer::start().await;
985
986        // Mock refresh endpoint that doesn't return refresh_token
987        Mock::given(method("POST"))
988            .and(path("/token"))
989            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
990                "access_token": "new-access-token",
991                "expires_in": 3600,
992                "token_type": "Bearer"
993            })))
994            .mount(&server)
995            .await;
996
997        let config = OAuthConfig {
998            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
999            token_endpoint: format!("{}/token", server.uri()),
1000            client_id: "test-client-id".to_string(),
1001            client_secret: None,
1002            redirect_uri: "http://127.0.0.1:4317".to_string(),
1003            scopes: vec!["scope1".to_string()],
1004            audience: None,
1005            additional_params: HashMap::new(),
1006        };
1007
1008        let http = Arc::new(Client::new());
1009        let flow = OAuthFlow::new(config, http).unwrap();
1010        let temp_dir = tempdir().unwrap();
1011        let store_path = temp_dir.path().join("credentials.json");
1012        let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
1013        let key = TokenStoreKey {
1014            profile: ApiProfile::Enterprise,
1015            project_number: Some("omission-test".to_string()),
1016            endpoint_location: Some("global".to_string()),
1017            user_hint: None,
1018        };
1019
1020        // Store initial refresh token with expired access token
1021        let initial_tokens = SerializedTokens {
1022            refresh_token: "original-refresh-token".to_string(),
1023            scopes: vec!["scope1".to_string()],
1024            expires_at: Some(OffsetDateTime::now_utc() - Duration::from_secs(3600)), // Expired
1025            token_type: "Bearer".to_string(),
1026            updated_at: OffsetDateTime::now_utc() - Duration::from_secs(3600),
1027        };
1028        store.save(&key, &initial_tokens).await.unwrap();
1029
1030        let provider = RefreshTokenProvider::new(flow, Arc::clone(&store), key.clone());
1031
1032        // Get access token (should trigger refresh because token is expired)
1033        let access_token = provider.access_token().await.unwrap();
1034        assert_eq!(access_token, "new-access-token");
1035
1036        // Verify original refresh token is preserved
1037        let stored = store.load(&key).await.unwrap().unwrap();
1038        assert_eq!(stored.refresh_token, "original-refresh-token");
1039
1040        // Cleanup
1041        store.delete(&key).await.unwrap();
1042    }
1043
1044    #[tokio::test]
1045    #[serial_test::serial]
1046    async fn test_refresh_token_updated_when_included_in_response() {
1047        let server = MockServer::start().await;
1048
1049        // Mock refresh endpoint that returns new refresh_token
1050        Mock::given(method("POST"))
1051            .and(path("/token"))
1052            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1053                "access_token": "new-access-token",
1054                "refresh_token": "new-refresh-token",
1055                "expires_in": 3600,
1056                "token_type": "Bearer"
1057            })))
1058            .mount(&server)
1059            .await;
1060
1061        let config = OAuthConfig {
1062            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1063            token_endpoint: format!("{}/token", server.uri()),
1064            client_id: "test-client-id".to_string(),
1065            client_secret: None,
1066            redirect_uri: "http://127.0.0.1:4317".to_string(),
1067            scopes: vec!["scope1".to_string()],
1068            audience: None,
1069            additional_params: HashMap::new(),
1070        };
1071
1072        let http = Arc::new(Client::new());
1073        let flow = OAuthFlow::new(config, http).unwrap();
1074        let temp_dir = tempdir().unwrap();
1075        let store_path = temp_dir.path().join("credentials.json");
1076        let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
1077        let key = TokenStoreKey {
1078            profile: ApiProfile::Enterprise,
1079            project_number: Some("update-test".to_string()),
1080            endpoint_location: Some("global".to_string()),
1081            user_hint: None,
1082        };
1083
1084        // Store initial refresh token with expired access token
1085        let initial_tokens = SerializedTokens {
1086            refresh_token: "original-refresh-token".to_string(),
1087            scopes: vec!["scope1".to_string()],
1088            expires_at: Some(OffsetDateTime::now_utc() - Duration::from_secs(3600)), // Expired
1089            token_type: "Bearer".to_string(),
1090            updated_at: OffsetDateTime::now_utc() - Duration::from_secs(3600),
1091        };
1092        store.save(&key, &initial_tokens).await.unwrap();
1093
1094        let provider = RefreshTokenProvider::new(flow, Arc::clone(&store), key.clone());
1095
1096        // Get access token (should trigger refresh because token is expired)
1097        let access_token = provider.access_token().await.unwrap();
1098        assert_eq!(access_token, "new-access-token");
1099
1100        // Verify refresh token is updated
1101        let stored = store.load(&key).await.unwrap().unwrap();
1102        assert_eq!(stored.refresh_token, "new-refresh-token");
1103
1104        // Cleanup
1105        store.delete(&key).await.unwrap();
1106    }
1107
1108    // Test 6: State validation failure tests (CSRF)
1109    #[tokio::test]
1110    async fn test_state_mismatch_detection() {
1111        let config = OAuthConfig {
1112            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1113            token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1114            client_id: "test-client-id".to_string(),
1115            client_secret: None,
1116            redirect_uri: "http://127.0.0.1:4317".to_string(),
1117            scopes: vec!["scope1".to_string()],
1118            audience: None,
1119            additional_params: HashMap::new(),
1120        };
1121
1122        let http = Arc::new(Client::new());
1123        let flow = OAuthFlow::new(config, http).unwrap();
1124
1125        let context = flow.build_authorize_url(&AuthorizeParams {
1126            state: None,
1127            code_challenge: None,
1128            code_challenge_method: None,
1129        });
1130
1131        // Simulate state mismatch (CSRF attack)
1132        let wrong_state = "attacker-controlled-state";
1133        assert_ne!(context.state, wrong_state);
1134
1135        // In actual implementation, this would be caught by the callback handler
1136        // which compares the received state with context.state
1137    }
1138
1139    #[test]
1140    fn test_authorize_url_contains_required_parameters() {
1141        let config = OAuthConfig {
1142            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1143            token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1144            client_id: "test-client-id".to_string(),
1145            client_secret: None,
1146            redirect_uri: "http://127.0.0.1:4317".to_string(),
1147            scopes: vec!["scope1".to_string(), "scope2".to_string()],
1148            audience: None,
1149            additional_params: HashMap::new(),
1150        };
1151
1152        let http = Arc::new(Client::new());
1153        let flow = OAuthFlow::new(config, http).unwrap();
1154
1155        let context = flow.build_authorize_url(&AuthorizeParams {
1156            state: None,
1157            code_challenge: None,
1158            code_challenge_method: None,
1159        });
1160
1161        let url = url::Url::parse(&context.url).unwrap();
1162        let params: HashMap<_, _> = url.query_pairs().collect();
1163
1164        // Verify required PKCE and OAuth parameters
1165        assert!(params.contains_key("client_id"));
1166        assert!(params.contains_key("redirect_uri"));
1167        assert!(params.contains_key("response_type"));
1168        assert_eq!(params.get("response_type").unwrap(), "code");
1169        assert!(params.contains_key("scope"));
1170        assert!(params.contains_key("state"));
1171        assert!(params.contains_key("code_challenge"));
1172        assert!(params.contains_key("code_challenge_method"));
1173        assert_eq!(params.get("code_challenge_method").unwrap(), "S256");
1174        assert!(params.contains_key("access_type"));
1175        assert_eq!(params.get("access_type").unwrap(), "offline");
1176        assert!(params.contains_key("prompt"));
1177        assert_eq!(params.get("prompt").unwrap(), "consent");
1178    }
1179
1180    #[test]
1181    fn test_custom_state_is_preserved() {
1182        let config = OAuthConfig {
1183            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1184            token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1185            client_id: "test-client-id".to_string(),
1186            client_secret: None,
1187            redirect_uri: "http://127.0.0.1:4317".to_string(),
1188            scopes: vec!["scope1".to_string()],
1189            audience: None,
1190            additional_params: HashMap::new(),
1191        };
1192
1193        let http = Arc::new(Client::new());
1194        let flow = OAuthFlow::new(config, http).unwrap();
1195
1196        let custom_state = "my-custom-state-value";
1197        let context = flow.build_authorize_url(&AuthorizeParams {
1198            state: Some(custom_state.to_string()),
1199            code_challenge: None,
1200            code_challenge_method: None,
1201        });
1202
1203        assert_eq!(context.state, custom_state);
1204        assert!(context.url.contains(&format!("state={}", custom_state)));
1205    }
1206
1207    #[tokio::test]
1208    #[serial_test::serial]
1209    async fn revoke_refresh_token_succeeds_with_fake_server() {
1210        let server = FakeOAuthServer::start().await;
1211        std::env::set_var(REVOKE_ENDPOINT_ENV, server.revoke_endpoint());
1212
1213        let config = OAuthConfig {
1214            auth_endpoint: format!("{}/auth", server.base_uri()),
1215            token_endpoint: server.token_endpoint(),
1216            client_id: "client-id".into(),
1217            client_secret: None,
1218            redirect_uri: "http://localhost:1234".into(),
1219            scopes: vec!["scope".into()],
1220            audience: None,
1221            additional_params: HashMap::new(),
1222        };
1223        let flow = OAuthFlow::new(config, Arc::new(Client::new())).unwrap();
1224        flow.revoke_refresh_token("fake_refresh_token")
1225            .await
1226            .unwrap();
1227        std::env::remove_var(REVOKE_ENDPOINT_ENV);
1228    }
1229
1230    #[tokio::test]
1231    #[serial_test::serial]
1232    async fn revoke_refresh_token_reports_error() {
1233        let server = FakeOAuthServer::start().await;
1234        std::env::set_var(
1235            REVOKE_ENDPOINT_ENV,
1236            format!("{}/invalid", server.base_uri()),
1237        );
1238
1239        let config = OAuthConfig {
1240            auth_endpoint: format!("{}/auth", server.base_uri()),
1241            token_endpoint: server.token_endpoint(),
1242            client_id: "client-id".into(),
1243            client_secret: None,
1244            redirect_uri: "http://localhost:1234".into(),
1245            scopes: vec!["scope".into()],
1246            audience: None,
1247            additional_params: HashMap::new(),
1248        };
1249        let flow = OAuthFlow::new(config, Arc::new(Client::new())).unwrap();
1250        let err = flow
1251            .revoke_refresh_token("fake_refresh_token")
1252            .await
1253            .unwrap_err();
1254        assert!(matches!(err, OAuthError::Revocation(_)));
1255        std::env::remove_var(REVOKE_ENDPOINT_ENV);
1256    }
1257
1258    #[tokio::test]
1259    async fn delete_file_removes_credentials_file() {
1260        let temp = tempdir().unwrap();
1261        let path = temp.path().join("credentials.json");
1262        let store = FileRefreshTokenStore::from_path(&path).unwrap();
1263        let key = TokenStoreKey {
1264            profile: ApiProfile::Enterprise,
1265            project_number: Some("123".into()),
1266            endpoint_location: Some("global".into()),
1267            user_hint: None,
1268        };
1269        let serialized = SerializedTokens {
1270            refresh_token: "token".into(),
1271            scopes: vec![],
1272            expires_at: None,
1273            token_type: "Bearer".into(),
1274            updated_at: OffsetDateTime::now_utc(),
1275        };
1276        store.save(&key, &serialized).await.unwrap();
1277        assert!(path.exists());
1278        store.delete_file().await.unwrap();
1279        assert!(!path.exists());
1280    }
1281}