1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use oauth2::{
9 basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointSet,
10 PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse,
11 TokenResponse as OAuth2TokenResponse, TokenUrl,
12};
13use parking_lot::RwLock;
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use time::OffsetDateTime;
17
18use crate::auth::{ProviderKind, TokenProvider};
19use crate::env::ApiProfile;
20use crate::error::{Error as CoreError, Result as CoreResult};
21
22mod config;
23mod error;
24pub mod loopback;
25#[cfg(test)]
26pub mod testing;
27
28pub use config::OAuthClientConfig;
29pub use error::{OAuthError, Result};
30
31impl From<OAuthError> for CoreError {
32 fn from(err: OAuthError) -> Self {
33 CoreError::TokenProvider(err.to_string())
34 }
35}
36
37#[derive(Debug, Clone)]
43pub struct OAuthConfig {
44 pub auth_endpoint: String,
45 pub token_endpoint: String,
46 pub client_id: String,
47 pub client_secret: Option<String>,
48 pub redirect_uri: String,
49 pub scopes: Vec<String>,
50 pub audience: Option<String>,
51 pub additional_params: HashMap<String, String>,
52}
53
/// Google's OAuth 2.0 token revocation endpoint, used by
/// `OAuthFlow::revoke_refresh_token` unless overridden.
const GOOGLE_REVOKE_ENDPOINT: &str = "https://oauth2.googleapis.com/revoke";
/// Environment variable whose value, when set, replaces `GOOGLE_REVOKE_ENDPOINT`
/// (the tests point it at a fake server).
const REVOKE_ENDPOINT_ENV: &str = "NBLM_OAUTH_REVOKE_ENDPOINT";
56
57impl OAuthConfig {
58 pub const DEFAULT_REDIRECT_URI: &str = "http://127.0.0.1:4317";
59 pub(crate) const AUTH_ENDPOINT: &str = "https://accounts.google.com/o/oauth2/v2/auth";
60 pub(crate) const TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
61 pub(crate) const SCOPE_CLOUD_PLATFORM: &str = "https://www.googleapis.com/auth/cloud-platform";
62 pub(crate) const SCOPE_DRIVE_FILE: &str = "https://www.googleapis.com/auth/drive.file";
63
64 pub fn google_default(_project_number: &str) -> Result<Self> {
66 OAuthClientConfig::from_env().map(|cfg| cfg.into_oauth_config())
67 }
68}
69
/// Tokens obtained from the token endpoint, with an absolute expiry time.
///
/// NOTE(review): the derived `Debug` prints `access_token` and `refresh_token`
/// as-is; avoid logging with `{:?}`.
/// NOTE(review): `expires_at` has no `#[serde(with = ...)]`, so it uses the `time`
/// crate's default serde representation — confirm this is intended wherever
/// this struct is serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthTokens {
    /// Bearer credential presented to APIs.
    pub access_token: String,
    /// Refresh token, if the endpoint returned one in this response.
    pub refresh_token: Option<String>,
    /// Absolute expiry of `access_token`.
    pub expires_at: OffsetDateTime,
    /// Granted scopes joined by single spaces, if the response listed them.
    pub scope: Option<String>,
    /// Token type, e.g. `"Bearer"`.
    pub token_type: String,
}
83
84impl OAuthTokens {
85 pub fn from_oauth2_response(
87 response: StandardTokenResponse<
88 oauth2::EmptyExtraTokenFields,
89 oauth2::basic::BasicTokenType,
90 >,
91 issued_at: OffsetDateTime,
92 ) -> Self {
93 let expires_at = issued_at
94 + response
95 .expires_in()
96 .map(|d| Duration::from_secs(d.as_secs()))
97 .unwrap_or_else(|| Duration::from_secs(3600));
98
99 let scope = response.scopes().map(|scopes| {
100 scopes
101 .iter()
102 .map(|s| s.as_str().to_string())
103 .collect::<Vec<_>>()
104 .join(" ")
105 });
106
107 Self {
108 access_token: response.access_token().secret().to_string(),
109 refresh_token: response.refresh_token().map(|rt| rt.secret().to_string()),
110 expires_at,
111 scope,
112 token_type: match response.token_type() {
113 oauth2::basic::BasicTokenType::Bearer => "Bearer".to_string(),
114 oauth2::basic::BasicTokenType::Mac => "MAC".to_string(),
115 oauth2::basic::BasicTokenType::Extension(s) => s.clone(),
116 },
117 }
118 }
119}
120
/// In-memory cached tokens plus the margin before expiry at which they are
/// considered stale (see `TokenCacheEntry::needs_refresh`).
#[derive(Debug, Clone)]
pub struct TokenCacheEntry {
    /// The cached tokens.
    pub tokens: OAuthTokens,
    /// How long before `tokens.expires_at` a refresh is triggered; `new` sets 60s.
    pub refresh_margin: Duration,
}
131
132impl TokenCacheEntry {
133 pub fn new(tokens: OAuthTokens) -> Self {
134 Self {
135 tokens,
136 refresh_margin: Duration::from_secs(60), }
138 }
139
140 pub fn needs_refresh(&self, now: OffsetDateTime) -> bool {
142 now >= (self.tokens.expires_at - self.refresh_margin)
143 }
144}
145
/// Identifies one set of stored credentials.
///
/// Its `Display` form (`profile[:project=..][:location=..][:user=..]`) is the
/// key used inside the credentials file.
///
/// NOTE(review): the parts are joined with ':' without escaping. Values that
/// themselves contain ':' or '=' could make two different keys render
/// identically — confirm that inputs cannot contain them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TokenStoreKey {
    /// API profile the tokens belong to.
    pub profile: ApiProfile,
    /// Optional project number qualifier.
    pub project_number: Option<String>,
    /// Optional endpoint location qualifier.
    pub endpoint_location: Option<String>,
    /// Optional user qualifier.
    pub user_hint: Option<String>,
}
158
159impl fmt::Display for TokenStoreKey {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 let mut parts = vec![self.profile.as_str().to_string()];
162
163 if let Some(ref project) = self.project_number {
164 parts.push(format!("project={}", project));
165 }
166
167 if let Some(ref location) = self.endpoint_location {
168 parts.push(format!("location={}", location));
169 }
170
171 if let Some(ref user) = self.user_hint {
172 parts.push(format!("user={}", user));
173 }
174
175 write!(f, "{}", parts.join(":"))
176 }
177}
178
/// Persisted form of a credential entry written by a `RefreshTokenStore`.
///
/// NOTE(review): `updated_at` is serialized as RFC 3339, but `expires_at` has no
/// `with` attribute and uses the `time` crate's default representation. Changing
/// either would break previously written files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedTokens {
    /// Long-lived refresh token used to mint new access tokens.
    pub refresh_token: String,
    /// Scopes associated with the grant.
    pub scopes: Vec<String>,
    /// Expiry of the most recent access token, if known.
    pub expires_at: Option<OffsetDateTime>,
    /// Token type, e.g. `"Bearer"`.
    pub token_type: String,
    /// When this entry was last written.
    #[serde(with = "time::serde::rfc3339")]
    pub updated_at: OffsetDateTime,
}
193
/// On-disk JSON layout of the credentials file used by `FileRefreshTokenStore`.
#[derive(Debug, Serialize, Deserialize)]
struct CredentialsFile {
    /// Format version; `CredentialsFile::new` writes 1.
    version: u32,
    /// Entries keyed by the `Display` string of a `TokenStoreKey`.
    entries: HashMap<String, SerializedTokens>,
}
200
201impl CredentialsFile {
202 fn new() -> Self {
203 Self {
204 version: 1,
205 entries: HashMap::new(),
206 }
207 }
208}
209
/// Persistent storage for refresh-token entries, keyed by `TokenStoreKey`.
#[async_trait]
pub trait RefreshTokenStore: Send + Sync {
    /// Loads the entry for `key`. `Ok(None)` means nothing is stored (as
    /// `FileRefreshTokenStore` reports for an unknown key).
    async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>>;

    /// Stores `tokens` under `key`, replacing any previous entry.
    async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()>;

    /// Removes the entry stored under `key`.
    async fn delete(&self, key: &TokenStoreKey) -> Result<()>;
}
226
/// Environment variable naming a directory that overrides where
/// `FileRefreshTokenStore::new` places `credentials.json`.
const CONFIG_DIR_ENV: &str = "NBLM_CONFIG_DIR";
232
/// `RefreshTokenStore` backed by a single JSON credentials file whose entries
/// are keyed by the `TokenStoreKey` display string.
pub struct FileRefreshTokenStore {
    /// Location of the credentials JSON file.
    file_path: std::path::PathBuf,
}
238impl FileRefreshTokenStore {
239 pub fn new() -> Result<Self> {
241 if let Ok(custom_dir) = std::env::var(CONFIG_DIR_ENV) {
242 let file_path = PathBuf::from(custom_dir).join("credentials.json");
243 return Self::from_path(file_path);
244 }
245
246 let dirs = directories::ProjectDirs::from("com", "nblm", "nblm-rs")
247 .ok_or_else(|| OAuthError::Config("failed to find config directory".to_string()))?;
248
249 let config_dir = dirs.config_dir();
250 let file_path = config_dir.join("credentials.json");
251
252 Self::from_path(file_path)
253 }
254
255 pub fn from_path(path: impl Into<PathBuf>) -> Result<Self> {
257 Ok(Self {
258 file_path: path.into(),
259 })
260 }
261
262 async fn ensure_config_dir(&self) -> Result<()> {
264 if let Some(config_dir) = self.file_path.parent() {
265 tokio::fs::create_dir_all(config_dir).await.map_err(|e| {
266 OAuthError::Config(format!("failed to create config directory: {}", e))
267 })?;
268
269 #[cfg(unix)]
270 {
271 use std::os::unix::fs::PermissionsExt;
272 let mut perms = tokio::fs::metadata(config_dir)
273 .await
274 .map_err(|e| {
275 OAuthError::Config(format!("failed to get config dir metadata: {}", e))
276 })?
277 .permissions();
278 perms.set_mode(0o700);
279 tokio::fs::set_permissions(config_dir, perms)
280 .await
281 .map_err(|e| {
282 OAuthError::Config(format!("failed to set config dir permissions: {}", e))
283 })?;
284 }
285 }
286 Ok(())
287 }
288
289 async fn load_file(&self) -> Result<CredentialsFile> {
291 self.ensure_config_dir().await?;
292
293 if !self.file_path.exists() {
294 return Ok(CredentialsFile::new());
295 }
296
297 let content = tokio::fs::read_to_string(&self.file_path)
298 .await
299 .map_err(|e| OAuthError::Config(format!("failed to read credentials file: {}", e)))?;
300
301 let file: CredentialsFile = serde_json::from_str(&content)
302 .map_err(|e| OAuthError::Config(format!("failed to parse credentials file: {}", e)))?;
303
304 Ok(file)
305 }
306
307 async fn save_file(&self, file: &CredentialsFile) -> Result<()> {
309 self.ensure_config_dir().await?;
310
311 let content = serde_json::to_string_pretty(file)
312 .map_err(|e| OAuthError::Config(format!("failed to serialize credentials: {}", e)))?;
313
314 let random_suffix = {
317 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
318 use rand::RngCore;
319 let mut rng = rand::rng();
320 let mut random_bytes = [0u8; 8];
321 rng.fill_bytes(&mut random_bytes);
322 URL_SAFE_NO_PAD.encode(random_bytes)
323 };
324
325 let temp_path = self.file_path.with_file_name(format!(
326 "{}.{}.tmp",
327 self.file_path
328 .file_name()
329 .and_then(|n| n.to_str())
330 .unwrap_or("credentials.json"),
331 random_suffix
332 ));
333 tokio::fs::write(&temp_path, content)
334 .await
335 .map_err(|e| OAuthError::Config(format!("failed to write credentials file: {}", e)))?;
336
337 #[cfg(unix)]
338 {
339 use std::os::unix::fs::PermissionsExt;
340 if let Ok(metadata) = tokio::fs::metadata(&temp_path).await {
341 let mut perms = metadata.permissions();
342 perms.set_mode(0o600);
343 let _ = tokio::fs::set_permissions(&temp_path, perms).await;
344 }
345 }
346
347 tokio::fs::rename(&temp_path, &self.file_path)
348 .await
349 .map_err(|e| OAuthError::Config(format!("failed to rename temp file: {}", e)))?;
350
351 Ok(())
352 }
353
354 pub async fn delete_file(&self) -> Result<()> {
356 match tokio::fs::remove_file(&self.file_path).await {
357 Ok(_) => Ok(()),
358 Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
359 Err(err) => Err(OAuthError::Storage(err)),
360 }
361 }
362}
363
364#[async_trait]
365impl RefreshTokenStore for FileRefreshTokenStore {
366 async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>> {
367 let file = self.load_file().await?;
368 Ok(file.entries.get(&key.to_string()).cloned())
369 }
370
371 async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()> {
372 let mut file = self.load_file().await?;
373 file.entries.insert(key.to_string(), tokens.clone());
374 self.save_file(&file).await
375 }
376
377 async fn delete(&self, key: &TokenStoreKey) -> Result<()> {
378 let mut file = self.load_file().await?;
379 file.entries.remove(&key.to_string());
380 self.save_file(&file).await
381 }
382}
383
/// Caller-supplied options for `OAuthFlow::build_authorize_url`.
#[derive(Debug, Clone)]
pub struct AuthorizeParams {
    /// Explicit `state` value. When `None`, a random CSRF token is generated.
    pub state: Option<String>,
    /// NOTE(review): not read by `build_authorize_url`, which always generates
    /// its own S256 PKCE pair.
    pub code_challenge: Option<String>,
    /// NOTE(review): not read by `build_authorize_url` (the method is always S256).
    pub code_challenge_method: Option<String>,
}
395
/// Result of `OAuthFlow::build_authorize_url`: the URL to open plus the
/// secrets needed to complete the exchange.
#[derive(Debug, Clone)]
pub struct AuthorizeContext {
    /// Fully formed authorization URL.
    pub url: String,
    /// `state` value embedded in `url`; the caller should compare it with the
    /// `state` returned on the redirect.
    pub state: String,
    /// PKCE code verifier that matches the challenge in `url`.
    pub code_verifier: String,
    /// Creation time plus 600 seconds.
    /// NOTE(review): `OAuthFlow::exchange_code` does not check this value.
    pub expires_at: OffsetDateTime,
}
404
/// Authorization-code + PKCE flow against a single OAuth provider.
pub struct OAuthFlow {
    /// `oauth2` client; `new` sets the auth and token endpoints (the two
    /// `EndpointSet` slots), and the other endpoint slots stay unset.
    /// Revocation is done with a raw HTTP request in `revoke_refresh_token`.
    client: BasicClient<
        EndpointSet,
        oauth2::EndpointNotSet,
        oauth2::EndpointNotSet,
        oauth2::EndpointNotSet,
        EndpointSet,
    >,
    /// Configuration the client was built from (scopes, extra params, audience).
    config: OAuthConfig,
    /// Shared HTTP client used for token and revocation requests.
    http: Arc<Client>,
}
417
418impl OAuthFlow {
419 pub fn new(config: OAuthConfig, http: Arc<Client>) -> Result<Self> {
421 let client_id = ClientId::new(config.client_id.clone());
422 let auth_url = AuthUrl::new(config.auth_endpoint.clone())
423 .map_err(|e| OAuthError::Config(format!("invalid auth_url: {}", e)))?;
424 let token_url = TokenUrl::new(config.token_endpoint.clone())
425 .map_err(|e| OAuthError::Config(format!("invalid token_url: {}", e)))?;
426 let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
427 .map_err(|e| OAuthError::Config(format!("invalid redirect_url: {}", e)))?;
428
429 let mut client_builder = BasicClient::new(client_id)
430 .set_auth_uri(auth_url)
431 .set_token_uri(token_url)
432 .set_redirect_uri(redirect_url);
433
434 if let Some(ref client_secret) = config.client_secret {
435 client_builder =
436 client_builder.set_client_secret(ClientSecret::new(client_secret.clone()));
437 }
438
439 Ok(Self {
440 client: client_builder,
441 config,
442 http,
443 })
444 }
445
446 pub fn build_authorize_url(&self, params: &AuthorizeParams) -> AuthorizeContext {
448 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
450
451 let csrf_token = if let Some(ref state) = params.state {
453 CsrfToken::new(state.clone())
454 } else {
455 CsrfToken::new_random()
456 };
457
458 let mut auth_request = self.client.authorize_url(|| csrf_token.clone());
460
461 for scope_str in &self.config.scopes {
463 auth_request = auth_request.add_scope(Scope::new(scope_str.clone()));
464 }
465
466 auth_request = auth_request.set_pkce_challenge(pkce_challenge);
468
469 for (key, value) in &self.config.additional_params {
471 auth_request = auth_request.add_extra_param(key, value);
472 }
473
474 let (auth_url, csrf_token_actual) = auth_request.url();
476
477 let mut url = url::Url::parse(auth_url.as_str()).expect("invalid auth_url");
479 url.query_pairs_mut()
480 .append_pair("access_type", "offline")
481 .append_pair("prompt", "consent");
482
483 if let Some(ref audience) = self.config.audience {
484 url.query_pairs_mut().append_pair("audience", audience);
485 }
486
487 let expires_at = OffsetDateTime::now_utc() + Duration::from_secs(600); AuthorizeContext {
490 url: url.to_string(),
491 state: csrf_token_actual.secret().to_string(),
492 code_verifier: pkce_verifier.secret().to_string(),
493 expires_at,
494 }
495 }
496
497 pub async fn exchange_code(
499 &self,
500 context: &AuthorizeContext,
501 code: &str,
502 ) -> Result<OAuthTokens> {
503 let code = AuthorizationCode::new(code.to_string());
504 let pkce_verifier = PkceCodeVerifier::new(context.code_verifier.clone());
505
506 let token_request = self
507 .client
508 .exchange_code(code)
509 .set_pkce_verifier(pkce_verifier);
510
511 let token_response = token_request
512 .request_async(self.http.as_ref())
513 .await
514 .map_err(|e| OAuthError::Config(format!("oauth token exchange failed: {}", e)))?;
515
516 Ok(OAuthTokens::from_oauth2_response(
517 token_response,
518 OffsetDateTime::now_utc(),
519 ))
520 }
521
522 pub async fn refresh(&self, refresh_token: &str) -> Result<OAuthTokens> {
524 let refresh_token = RefreshToken::new(refresh_token.to_string());
525
526 let token_request = self.client.exchange_refresh_token(&refresh_token);
527
528 let token_response = token_request
529 .request_async(self.http.as_ref())
530 .await
531 .map_err(|e| OAuthError::Config(format!("oauth token refresh failed: {}", e)))?;
532
533 Ok(OAuthTokens::from_oauth2_response(
534 token_response,
535 OffsetDateTime::now_utc(),
536 ))
537 }
538
539 pub async fn revoke_refresh_token(&self, refresh_token: &str) -> Result<()> {
541 if refresh_token.trim().is_empty() {
542 return Err(OAuthError::Revocation(
543 "refresh token cannot be empty".into(),
544 ));
545 }
546
547 let endpoint = std::env::var(REVOKE_ENDPOINT_ENV)
548 .unwrap_or_else(|_| GOOGLE_REVOKE_ENDPOINT.to_string());
549
550 let response = self
551 .http
552 .post(&endpoint)
553 .form(&[("token", refresh_token)])
554 .send()
555 .await
556 .map_err(|e| OAuthError::Revocation(format!("revocation request failed: {}", e)))?;
557
558 if response.status().is_success() {
559 Ok(())
560 } else {
561 let status = response.status();
562 let body = response.text().await.unwrap_or_default();
563 Err(OAuthError::Revocation(format!(
564 "revocation failed (status {}): {}",
565 status, body
566 )))
567 }
568 }
569}
570
/// `TokenProvider` that serves access tokens minted from a stored refresh token.
///
/// Tokens are cached in memory until they enter the refresh margin. Each
/// refresh persists the refresh token back to `store`.
pub struct RefreshTokenProvider<S: RefreshTokenStore> {
    /// OAuth flow used to perform refreshes.
    flow: OAuthFlow,
    /// Persistent store holding the refresh token.
    store: Arc<S>,
    /// Most recently obtained tokens, if any.
    cache: RwLock<Option<TokenCacheEntry>>,
    /// Key under which this provider's credentials are stored.
    store_key: TokenStoreKey,
}
582
583impl<S: RefreshTokenStore> RefreshTokenProvider<S> {
584 pub fn new(flow: OAuthFlow, store: Arc<S>, store_key: TokenStoreKey) -> Self {
586 Self {
587 flow,
588 store,
589 cache: RwLock::new(None),
590 store_key,
591 }
592 }
593
594 async fn ensure_tokens(&self, force_refresh: bool) -> Result<OAuthTokens> {
596 let now = OffsetDateTime::now_utc();
597
598 if !force_refresh {
600 if let Some(ref entry) = *self.cache.read() {
601 if !entry.needs_refresh(now) {
602 return Ok(entry.tokens.clone());
603 }
604 }
605 }
606
607 let stored = self
609 .store
610 .load(&self.store_key)
611 .await?
612 .ok_or_else(|| OAuthError::Config("refresh token unavailable".to_string()))?;
613
614 let tokens = self.flow.refresh(&stored.refresh_token).await?;
616 let refresh_token = tokens
617 .refresh_token
618 .clone()
619 .unwrap_or_else(|| stored.refresh_token.clone());
620 let scopes: Vec<String> = if let Some(scopes_str) = tokens.scope.as_ref() {
621 scopes_str.split_whitespace().map(String::from).collect()
622 } else if !stored.scopes.is_empty() {
623 stored.scopes.clone()
624 } else {
625 Vec::new()
626 };
627 let token_type = if !tokens.token_type.is_empty() {
628 tokens.token_type.clone()
629 } else {
630 stored.token_type.clone()
631 };
632
633 {
635 let mut cache = self.cache.write();
636 *cache = Some(TokenCacheEntry::new(tokens.clone()));
637 }
638
639 let serialized = SerializedTokens {
641 refresh_token,
642 scopes,
643 expires_at: Some(tokens.expires_at),
644 token_type,
645 updated_at: now,
646 };
647 self.store.save(&self.store_key, &serialized).await?;
648
649 Ok(tokens)
650 }
651}
652
653#[async_trait]
654impl<S: RefreshTokenStore> TokenProvider for RefreshTokenProvider<S> {
655 async fn access_token(&self) -> CoreResult<String> {
656 let tokens = self.ensure_tokens(false).await.map_err(CoreError::from)?;
657 Ok(tokens.access_token)
658 }
659
660 async fn refresh_token(&self) -> CoreResult<String> {
661 let tokens = self.ensure_tokens(true).await.map_err(CoreError::from)?;
662 Ok(tokens.access_token)
663 }
664
665 fn kind(&self) -> ProviderKind {
666 ProviderKind::UserOauth
667 }
668}
669
670#[cfg(test)]
671mod tests {
672 use super::*;
673 use crate::auth::oauth::testing::fake::FakeOAuthServer;
674 use tempfile::tempdir;
675 use wiremock::matchers::{method, path};
676 use wiremock::{Mock, MockServer, ResponseTemplate};
677
678 #[tokio::test]
679 async fn test_token_cache_entry_needs_refresh() {
680 let now = OffsetDateTime::now_utc();
681 let expires_at = now + Duration::from_secs(120); let tokens = OAuthTokens {
683 access_token: "test-token".to_string(),
684 refresh_token: None,
685 expires_at,
686 scope: None,
687 token_type: "Bearer".to_string(),
688 };
689
690 let entry = TokenCacheEntry::new(tokens);
691
692 assert!(!entry.needs_refresh(now));
694
695 let near_expiry = expires_at - Duration::from_secs(30);
697 assert!(entry.needs_refresh(near_expiry));
698 }
699
700 #[tokio::test]
701 async fn test_token_store_key_display() {
702 let key = TokenStoreKey {
703 profile: ApiProfile::Enterprise,
704 project_number: Some("123456".to_string()),
705 endpoint_location: Some("global".to_string()),
706 user_hint: None,
707 };
708
709 let display = key.to_string();
710 assert!(display.contains("enterprise"));
711 assert!(display.contains("project=123456"));
712 assert!(display.contains("location=global"));
713 }
714
715 #[tokio::test]
716 async fn test_oauth_tokens_from_oauth2_response() {
717 use oauth2::StandardTokenResponse;
718
719 let now = OffsetDateTime::now_utc();
720 let json_response = serde_json::json!({
722 "access_token": "access-token-123",
723 "refresh_token": "refresh-token-456",
724 "expires_in": 3600,
725 "token_type": "Bearer",
726 "scope": "scope1 scope2"
727 });
728
729 let response: StandardTokenResponse<
730 oauth2::EmptyExtraTokenFields,
731 oauth2::basic::BasicTokenType,
732 > = serde_json::from_value(json_response).unwrap();
733
734 let tokens = OAuthTokens::from_oauth2_response(response, now);
735 assert_eq!(tokens.access_token, "access-token-123");
736 assert_eq!(tokens.refresh_token, Some("refresh-token-456".to_string()));
737 assert_eq!(tokens.scope, Some("scope1 scope2".to_string()));
738 assert_eq!(tokens.token_type, "Bearer");
739 assert!(tokens.expires_at > now);
740 }
741
742 #[tokio::test]
743 async fn test_oauth_flow_refresh_token() {
744 let server = MockServer::start().await;
745
746 Mock::given(method("POST"))
747 .and(path("/token"))
748 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
749 "access_token": "new-access-token",
750 "expires_in": 3600,
751 "token_type": "Bearer"
752 })))
753 .mount(&server)
754 .await;
755
756 let config = OAuthConfig {
757 auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
758 token_endpoint: format!("{}/token", server.uri()),
759 client_id: "test-client-id".to_string(),
760 client_secret: None,
761 redirect_uri: "http://127.0.0.1:4317".to_string(),
762 scopes: vec!["scope1".to_string()],
763 audience: None,
764 additional_params: HashMap::new(),
765 };
766
767 let http = Arc::new(Client::new());
768 let flow = OAuthFlow::new(config, http).unwrap();
769
770 let tokens = flow.refresh("refresh-token-123").await.unwrap();
771 assert_eq!(tokens.access_token, "new-access-token");
772 }
773
774 #[test]
776 fn test_pkce_code_verifier_and_challenge_generation() {
777 let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
778
779 let verifier_str = verifier.secret();
781 assert!(!verifier_str.is_empty());
782 assert!(verifier_str.len() >= 43 && verifier_str.len() <= 128);
783 assert!(verifier_str
784 .chars()
785 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
786
787 let challenge_str = challenge.as_str();
789 assert!(!challenge_str.is_empty());
790 assert_eq!(challenge_str.len(), 43); assert!(challenge_str
792 .chars()
793 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
794 }
795
796 #[test]
797 fn test_pkce_generates_unique_values() {
798 let (challenge1, verifier1) = PkceCodeChallenge::new_random_sha256();
799 let (challenge2, verifier2) = PkceCodeChallenge::new_random_sha256();
800
801 assert_ne!(verifier1.secret(), verifier2.secret());
803 assert_ne!(challenge1.as_str(), challenge2.as_str());
804 }
805
806 #[test]
807 fn test_state_generation_uniqueness() {
808 let state1 = CsrfToken::new_random();
809 let state2 = CsrfToken::new_random();
810
811 assert_ne!(state1.secret(), state2.secret());
813 assert!(!state1.secret().is_empty());
814 assert!(!state2.secret().is_empty());
815 }
816
817 #[test]
819 fn test_token_needs_refresh_at_exact_expiry() {
820 let now = OffsetDateTime::now_utc();
821 let expires_at = now + Duration::from_secs(60);
822 let tokens = OAuthTokens {
823 access_token: "test-token".to_string(),
824 refresh_token: None,
825 expires_at,
826 scope: None,
827 token_type: "Bearer".to_string(),
828 };
829
830 let entry = TokenCacheEntry::new(tokens);
831
832 assert!(entry.needs_refresh(expires_at));
834 }
835
836 #[test]
837 fn test_token_needs_refresh_just_before_margin() {
838 let now = OffsetDateTime::now_utc();
839 let expires_at = now + Duration::from_secs(120);
840 let tokens = OAuthTokens {
841 access_token: "test-token".to_string(),
842 refresh_token: None,
843 expires_at,
844 scope: None,
845 token_type: "Bearer".to_string(),
846 };
847
848 let entry = TokenCacheEntry::new(tokens);
849
850 let before_margin = expires_at - Duration::from_secs(61);
852 assert!(!entry.needs_refresh(before_margin));
853
854 let at_margin = expires_at - Duration::from_secs(60);
856 assert!(entry.needs_refresh(at_margin));
857 }
858
859 #[test]
860 fn test_token_needs_refresh_after_expiry() {
861 let now = OffsetDateTime::now_utc();
862 let expires_at = now - Duration::from_secs(10); let tokens = OAuthTokens {
864 access_token: "test-token".to_string(),
865 refresh_token: None,
866 expires_at,
867 scope: None,
868 token_type: "Bearer".to_string(),
869 };
870
871 let entry = TokenCacheEntry::new(tokens);
872 assert!(entry.needs_refresh(now));
873 }
874
875 #[tokio::test]
877 #[serial_test::serial]
878 async fn test_file_store_concurrent_saves() {
879 use std::sync::Arc;
880 use tokio::task::JoinSet;
881
882 let temp_dir = tempdir().unwrap();
883 let store_path = temp_dir.path().join("credentials.json");
884 let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
885 let key = TokenStoreKey {
886 profile: ApiProfile::Enterprise,
887 project_number: Some("concurrent-test".to_string()),
888 endpoint_location: Some("global".to_string()),
889 user_hint: None,
890 };
891
892 let mut join_set = JoinSet::new();
893
894 for i in 0..10 {
896 let store_clone = Arc::clone(&store);
897 let key_clone = key.clone();
898 join_set.spawn(async move {
899 let tokens = SerializedTokens {
900 refresh_token: format!("token-{}", i),
901 scopes: vec!["scope1".to_string()],
902 expires_at: Some(OffsetDateTime::now_utc() + Duration::from_secs(3600)),
903 token_type: "Bearer".to_string(),
904 updated_at: OffsetDateTime::now_utc(),
905 };
906 store_clone.save(&key_clone, &tokens).await
907 });
908 }
909
910 while let Some(result) = join_set.join_next().await {
912 result.unwrap().unwrap();
913 }
914
915 let loaded = store.load(&key).await.unwrap();
917 assert!(loaded.is_some());
918 let loaded = loaded.unwrap();
919 assert!(loaded.refresh_token.starts_with("token-"));
920
921 store.delete(&key).await.unwrap();
923 }
924
925 #[tokio::test]
926 #[serial_test::serial]
927 async fn test_file_store_atomic_write() {
928 let temp_dir = tempdir().unwrap();
929 let store_path = temp_dir.path().join("credentials.json");
930 let store = FileRefreshTokenStore::from_path(&store_path).unwrap();
931 let key = TokenStoreKey {
932 profile: ApiProfile::Enterprise,
933 project_number: Some("atomic-test".to_string()),
934 endpoint_location: Some("global".to_string()),
935 user_hint: None,
936 };
937
938 let tokens = SerializedTokens {
939 refresh_token: "initial-token".to_string(),
940 scopes: vec!["scope1".to_string()],
941 expires_at: Some(OffsetDateTime::now_utc() + Duration::from_secs(3600)),
942 token_type: "Bearer".to_string(),
943 updated_at: OffsetDateTime::now_utc(),
944 };
945
946 store.save(&key, &tokens).await.unwrap();
947
948 let entries: Vec<_> = std::fs::read_dir(temp_dir.path())
950 .unwrap()
951 .map(|entry| entry.unwrap().file_name())
952 .collect();
953 assert_eq!(entries.len(), 1);
954 assert_eq!(entries[0], std::ffi::OsStr::new("credentials.json"));
955
956 store.delete(&key).await.unwrap();
958 }
959
960 #[test]
961 #[serial_test::serial]
962 fn test_file_store_respects_custom_config_dir() {
963 use std::env;
964
965 let temp_dir = tempdir().unwrap();
966 let key = CONFIG_DIR_ENV;
967 let original = env::var(key).ok();
968 env::set_var(key, temp_dir.path());
969
970 let store = FileRefreshTokenStore::new().expect("store should honor custom dir");
971 assert!(store.file_path.starts_with(temp_dir.path()));
972
973 if let Some(value) = original {
974 env::set_var(key, value);
975 } else {
976 env::remove_var(key);
977 }
978 }
979
980 #[tokio::test]
982 #[serial_test::serial]
983 async fn test_refresh_token_preserved_when_omitted_in_response() {
984 let server = MockServer::start().await;
985
986 Mock::given(method("POST"))
988 .and(path("/token"))
989 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
990 "access_token": "new-access-token",
991 "expires_in": 3600,
992 "token_type": "Bearer"
993 })))
994 .mount(&server)
995 .await;
996
997 let config = OAuthConfig {
998 auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
999 token_endpoint: format!("{}/token", server.uri()),
1000 client_id: "test-client-id".to_string(),
1001 client_secret: None,
1002 redirect_uri: "http://127.0.0.1:4317".to_string(),
1003 scopes: vec!["scope1".to_string()],
1004 audience: None,
1005 additional_params: HashMap::new(),
1006 };
1007
1008 let http = Arc::new(Client::new());
1009 let flow = OAuthFlow::new(config, http).unwrap();
1010 let temp_dir = tempdir().unwrap();
1011 let store_path = temp_dir.path().join("credentials.json");
1012 let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
1013 let key = TokenStoreKey {
1014 profile: ApiProfile::Enterprise,
1015 project_number: Some("omission-test".to_string()),
1016 endpoint_location: Some("global".to_string()),
1017 user_hint: None,
1018 };
1019
1020 let initial_tokens = SerializedTokens {
1022 refresh_token: "original-refresh-token".to_string(),
1023 scopes: vec!["scope1".to_string()],
1024 expires_at: Some(OffsetDateTime::now_utc() - Duration::from_secs(3600)), token_type: "Bearer".to_string(),
1026 updated_at: OffsetDateTime::now_utc() - Duration::from_secs(3600),
1027 };
1028 store.save(&key, &initial_tokens).await.unwrap();
1029
1030 let provider = RefreshTokenProvider::new(flow, Arc::clone(&store), key.clone());
1031
1032 let access_token = provider.access_token().await.unwrap();
1034 assert_eq!(access_token, "new-access-token");
1035
1036 let stored = store.load(&key).await.unwrap().unwrap();
1038 assert_eq!(stored.refresh_token, "original-refresh-token");
1039
1040 store.delete(&key).await.unwrap();
1042 }
1043
1044 #[tokio::test]
1045 #[serial_test::serial]
1046 async fn test_refresh_token_updated_when_included_in_response() {
1047 let server = MockServer::start().await;
1048
1049 Mock::given(method("POST"))
1051 .and(path("/token"))
1052 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1053 "access_token": "new-access-token",
1054 "refresh_token": "new-refresh-token",
1055 "expires_in": 3600,
1056 "token_type": "Bearer"
1057 })))
1058 .mount(&server)
1059 .await;
1060
1061 let config = OAuthConfig {
1062 auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1063 token_endpoint: format!("{}/token", server.uri()),
1064 client_id: "test-client-id".to_string(),
1065 client_secret: None,
1066 redirect_uri: "http://127.0.0.1:4317".to_string(),
1067 scopes: vec!["scope1".to_string()],
1068 audience: None,
1069 additional_params: HashMap::new(),
1070 };
1071
1072 let http = Arc::new(Client::new());
1073 let flow = OAuthFlow::new(config, http).unwrap();
1074 let temp_dir = tempdir().unwrap();
1075 let store_path = temp_dir.path().join("credentials.json");
1076 let store = Arc::new(FileRefreshTokenStore::from_path(&store_path).unwrap());
1077 let key = TokenStoreKey {
1078 profile: ApiProfile::Enterprise,
1079 project_number: Some("update-test".to_string()),
1080 endpoint_location: Some("global".to_string()),
1081 user_hint: None,
1082 };
1083
1084 let initial_tokens = SerializedTokens {
1086 refresh_token: "original-refresh-token".to_string(),
1087 scopes: vec!["scope1".to_string()],
1088 expires_at: Some(OffsetDateTime::now_utc() - Duration::from_secs(3600)), token_type: "Bearer".to_string(),
1090 updated_at: OffsetDateTime::now_utc() - Duration::from_secs(3600),
1091 };
1092 store.save(&key, &initial_tokens).await.unwrap();
1093
1094 let provider = RefreshTokenProvider::new(flow, Arc::clone(&store), key.clone());
1095
1096 let access_token = provider.access_token().await.unwrap();
1098 assert_eq!(access_token, "new-access-token");
1099
1100 let stored = store.load(&key).await.unwrap().unwrap();
1102 assert_eq!(stored.refresh_token, "new-refresh-token");
1103
1104 store.delete(&key).await.unwrap();
1106 }
1107
1108 #[tokio::test]
1110 async fn test_state_mismatch_detection() {
1111 let config = OAuthConfig {
1112 auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1113 token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1114 client_id: "test-client-id".to_string(),
1115 client_secret: None,
1116 redirect_uri: "http://127.0.0.1:4317".to_string(),
1117 scopes: vec!["scope1".to_string()],
1118 audience: None,
1119 additional_params: HashMap::new(),
1120 };
1121
1122 let http = Arc::new(Client::new());
1123 let flow = OAuthFlow::new(config, http).unwrap();
1124
1125 let context = flow.build_authorize_url(&AuthorizeParams {
1126 state: None,
1127 code_challenge: None,
1128 code_challenge_method: None,
1129 });
1130
1131 let wrong_state = "attacker-controlled-state";
1133 assert_ne!(context.state, wrong_state);
1134
1135 }
1138
1139 #[test]
1140 fn test_authorize_url_contains_required_parameters() {
1141 let config = OAuthConfig {
1142 auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1143 token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1144 client_id: "test-client-id".to_string(),
1145 client_secret: None,
1146 redirect_uri: "http://127.0.0.1:4317".to_string(),
1147 scopes: vec!["scope1".to_string(), "scope2".to_string()],
1148 audience: None,
1149 additional_params: HashMap::new(),
1150 };
1151
1152 let http = Arc::new(Client::new());
1153 let flow = OAuthFlow::new(config, http).unwrap();
1154
1155 let context = flow.build_authorize_url(&AuthorizeParams {
1156 state: None,
1157 code_challenge: None,
1158 code_challenge_method: None,
1159 });
1160
1161 let url = url::Url::parse(&context.url).unwrap();
1162 let params: HashMap<_, _> = url.query_pairs().collect();
1163
1164 assert!(params.contains_key("client_id"));
1166 assert!(params.contains_key("redirect_uri"));
1167 assert!(params.contains_key("response_type"));
1168 assert_eq!(params.get("response_type").unwrap(), "code");
1169 assert!(params.contains_key("scope"));
1170 assert!(params.contains_key("state"));
1171 assert!(params.contains_key("code_challenge"));
1172 assert!(params.contains_key("code_challenge_method"));
1173 assert_eq!(params.get("code_challenge_method").unwrap(), "S256");
1174 assert!(params.contains_key("access_type"));
1175 assert_eq!(params.get("access_type").unwrap(), "offline");
1176 assert!(params.contains_key("prompt"));
1177 assert_eq!(params.get("prompt").unwrap(), "consent");
1178 }
1179
1180 #[test]
1181 fn test_custom_state_is_preserved() {
1182 let config = OAuthConfig {
1183 auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1184 token_endpoint: "https://oauth2.googleapis.com/token".to_string(),
1185 client_id: "test-client-id".to_string(),
1186 client_secret: None,
1187 redirect_uri: "http://127.0.0.1:4317".to_string(),
1188 scopes: vec!["scope1".to_string()],
1189 audience: None,
1190 additional_params: HashMap::new(),
1191 };
1192
1193 let http = Arc::new(Client::new());
1194 let flow = OAuthFlow::new(config, http).unwrap();
1195
1196 let custom_state = "my-custom-state-value";
1197 let context = flow.build_authorize_url(&AuthorizeParams {
1198 state: Some(custom_state.to_string()),
1199 code_challenge: None,
1200 code_challenge_method: None,
1201 });
1202
1203 assert_eq!(context.state, custom_state);
1204 assert!(context.url.contains(&format!("state={}", custom_state)));
1205 }
1206
1207 #[tokio::test]
1208 #[serial_test::serial]
1209 async fn revoke_refresh_token_succeeds_with_fake_server() {
1210 let server = FakeOAuthServer::start().await;
1211 std::env::set_var(REVOKE_ENDPOINT_ENV, server.revoke_endpoint());
1212
1213 let config = OAuthConfig {
1214 auth_endpoint: format!("{}/auth", server.base_uri()),
1215 token_endpoint: server.token_endpoint(),
1216 client_id: "client-id".into(),
1217 client_secret: None,
1218 redirect_uri: "http://localhost:1234".into(),
1219 scopes: vec!["scope".into()],
1220 audience: None,
1221 additional_params: HashMap::new(),
1222 };
1223 let flow = OAuthFlow::new(config, Arc::new(Client::new())).unwrap();
1224 flow.revoke_refresh_token("fake_refresh_token")
1225 .await
1226 .unwrap();
1227 std::env::remove_var(REVOKE_ENDPOINT_ENV);
1228 }
1229
1230 #[tokio::test]
1231 #[serial_test::serial]
1232 async fn revoke_refresh_token_reports_error() {
1233 let server = FakeOAuthServer::start().await;
1234 std::env::set_var(
1235 REVOKE_ENDPOINT_ENV,
1236 format!("{}/invalid", server.base_uri()),
1237 );
1238
1239 let config = OAuthConfig {
1240 auth_endpoint: format!("{}/auth", server.base_uri()),
1241 token_endpoint: server.token_endpoint(),
1242 client_id: "client-id".into(),
1243 client_secret: None,
1244 redirect_uri: "http://localhost:1234".into(),
1245 scopes: vec!["scope".into()],
1246 audience: None,
1247 additional_params: HashMap::new(),
1248 };
1249 let flow = OAuthFlow::new(config, Arc::new(Client::new())).unwrap();
1250 let err = flow
1251 .revoke_refresh_token("fake_refresh_token")
1252 .await
1253 .unwrap_err();
1254 assert!(matches!(err, OAuthError::Revocation(_)));
1255 std::env::remove_var(REVOKE_ENDPOINT_ENV);
1256 }
1257
1258 #[tokio::test]
1259 async fn delete_file_removes_credentials_file() {
1260 let temp = tempdir().unwrap();
1261 let path = temp.path().join("credentials.json");
1262 let store = FileRefreshTokenStore::from_path(&path).unwrap();
1263 let key = TokenStoreKey {
1264 profile: ApiProfile::Enterprise,
1265 project_number: Some("123".into()),
1266 endpoint_location: Some("global".into()),
1267 user_hint: None,
1268 };
1269 let serialized = SerializedTokens {
1270 refresh_token: "token".into(),
1271 scopes: vec![],
1272 expires_at: None,
1273 token_type: "Bearer".into(),
1274 updated_at: OffsetDateTime::now_utc(),
1275 };
1276 store.save(&key, &serialized).await.unwrap();
1277 assert!(path.exists());
1278 store.delete_file().await.unwrap();
1279 assert!(!path.exists());
1280 }
1281}