1use std::collections::HashMap;
2use std::fmt;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use oauth2::{
9 basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointSet,
10 PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse,
11 TokenResponse as OAuth2TokenResponse, TokenUrl,
12};
13use parking_lot::RwLock;
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use time::OffsetDateTime;
17
18use crate::auth::{ProviderKind, TokenProvider};
19use crate::env::ApiProfile;
20use crate::error::{Error, Result};
21
22pub mod loopback;
23
/// Settings needed to run the OAuth 2.0 authorization-code (PKCE) flow.
///
/// `Debug` is implemented by hand so that `client_secret` is never written to
/// logs or panic messages; every other field is printed verbatim.
#[derive(Clone)]
pub struct OAuthConfig {
    /// Authorization endpoint the user's browser is sent to.
    pub auth_endpoint: String,
    /// Token endpoint used for code exchange and refresh.
    pub token_endpoint: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional client secret; when `None` the OAuth client is built without one.
    pub client_secret: Option<String>,
    /// Redirect URI handed to the OAuth client.
    pub redirect_uri: String,
    /// Scopes requested on the authorization URL.
    pub scopes: Vec<String>,
    /// Optional `audience` query parameter for the authorization URL.
    pub audience: Option<String>,
    /// Extra query parameters appended to the authorization URL.
    pub additional_params: HashMap<String, String>,
}

impl fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("auth_endpoint", &self.auth_endpoint)
            .field("token_endpoint", &self.token_endpoint)
            .field("client_id", &self.client_id)
            // Keep the presence of a secret visible, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "[redacted]"),
            )
            .field("redirect_uri", &self.redirect_uri)
            .field("scopes", &self.scopes)
            .field("audience", &self.audience)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
40
41impl OAuthConfig {
42 pub const DEFAULT_REDIRECT_URI: &str = "http://127.0.0.1:4317";
43 const AUTH_ENDPOINT: &str = "https://accounts.google.com/o/oauth2/v2/auth";
44 const TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
45 const SCOPE_CLOUD_PLATFORM: &str = "https://www.googleapis.com/auth/cloud-platform";
46 const SCOPE_DRIVE_FILE: &str = "https://www.googleapis.com/auth/drive.file";
47
48 pub fn google_default(_project_number: &str) -> Result<Self> {
50 let client_id = std::env::var("NBLM_OAUTH_CLIENT_ID").map_err(|_| {
51 Error::TokenProvider(
52 "NBLM_OAUTH_CLIENT_ID is required for user OAuth authentication".to_string(),
53 )
54 })?;
55
56 let audience = std::env::var("NBLM_OAUTH_AUDIENCE").ok();
57
58 Ok(Self {
59 auth_endpoint: Self::AUTH_ENDPOINT.to_string(),
60 token_endpoint: Self::TOKEN_ENDPOINT.to_string(),
61 client_id,
62 client_secret: std::env::var("NBLM_OAUTH_CLIENT_SECRET").ok(),
63 redirect_uri: std::env::var("NBLM_OAUTH_REDIRECT_URI")
64 .unwrap_or_else(|_| Self::DEFAULT_REDIRECT_URI.to_string()),
65 scopes: vec![
66 Self::SCOPE_CLOUD_PLATFORM.to_string(),
67 Self::SCOPE_DRIVE_FILE.to_string(),
68 ],
69 audience,
70 additional_params: HashMap::new(),
71 })
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct OAuthTokens {
82 pub access_token: String,
83 pub refresh_token: Option<String>,
84 pub expires_at: OffsetDateTime,
85 pub scope: Option<String>,
86 pub token_type: String,
87}
88
89impl OAuthTokens {
90 pub fn from_oauth2_response(
92 response: StandardTokenResponse<
93 oauth2::EmptyExtraTokenFields,
94 oauth2::basic::BasicTokenType,
95 >,
96 issued_at: OffsetDateTime,
97 ) -> Self {
98 let expires_at = issued_at
99 + response
100 .expires_in()
101 .map(|d| Duration::from_secs(d.as_secs()))
102 .unwrap_or_else(|| Duration::from_secs(3600));
103
104 let scope = response.scopes().map(|scopes| {
105 scopes
106 .iter()
107 .map(|s| s.as_str().to_string())
108 .collect::<Vec<_>>()
109 .join(" ")
110 });
111
112 Self {
113 access_token: response.access_token().secret().to_string(),
114 refresh_token: response.refresh_token().map(|rt| rt.secret().to_string()),
115 expires_at,
116 scope,
117 token_type: match response.token_type() {
118 oauth2::basic::BasicTokenType::Bearer => "Bearer".to_string(),
119 oauth2::basic::BasicTokenType::Mac => "MAC".to_string(),
120 oauth2::basic::BasicTokenType::Extension(s) => s.clone(),
121 },
122 }
123 }
124}
125
126#[derive(Debug, Clone)]
132pub struct TokenCacheEntry {
133 pub tokens: OAuthTokens,
134 pub refresh_margin: Duration,
135}
136
137impl TokenCacheEntry {
138 pub fn new(tokens: OAuthTokens) -> Self {
139 Self {
140 tokens,
141 refresh_margin: Duration::from_secs(60), }
143 }
144
145 pub fn needs_refresh(&self, now: OffsetDateTime) -> bool {
147 now >= (self.tokens.expires_at - self.refresh_margin)
148 }
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, Hash)]
157pub struct TokenStoreKey {
158 pub profile: ApiProfile,
159 pub project_number: Option<String>,
160 pub endpoint_location: Option<String>,
161 pub user_hint: Option<String>,
162}
163
164impl fmt::Display for TokenStoreKey {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 let mut parts = vec![self.profile.as_str().to_string()];
167
168 if let Some(ref project) = self.project_number {
169 parts.push(format!("project={}", project));
170 }
171
172 if let Some(ref location) = self.endpoint_location {
173 parts.push(format!("location={}", location));
174 }
175
176 if let Some(ref user) = self.user_hint {
177 parts.push(format!("user={}", user));
178 }
179
180 write!(f, "{}", parts.join(":"))
181 }
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct SerializedTokens {
191 pub refresh_token: String,
192 pub scopes: Vec<String>,
193 pub expires_at: Option<OffsetDateTime>,
194 pub token_type: String,
195 #[serde(with = "time::serde::rfc3339")]
196 pub updated_at: OffsetDateTime,
197}
198
199#[derive(Debug, Serialize, Deserialize)]
201struct CredentialsFile {
202 version: u32,
203 entries: HashMap<String, SerializedTokens>,
204}
205
206impl CredentialsFile {
207 fn new() -> Self {
208 Self {
209 version: 1,
210 entries: HashMap::new(),
211 }
212 }
213}
214
215#[async_trait]
221pub trait RefreshTokenStore: Send + Sync {
222 async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>>;
224
225 async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()>;
227
228 async fn delete(&self, key: &TokenStoreKey) -> Result<()>;
230}
231
/// [`RefreshTokenStore`] backed by a single JSON credentials file shared by all keys.
pub struct FileRefreshTokenStore {
    /// Location of the credentials file.
    file_path: PathBuf,
}
240
241impl FileRefreshTokenStore {
242 pub fn new() -> Result<Self> {
244 let dirs = directories::ProjectDirs::from("com", "nblm", "nblm-rs")
245 .ok_or_else(|| Error::TokenProvider("failed to find config directory".to_string()))?;
246
247 let config_dir = dirs.config_dir();
248 let file_path = config_dir.join("credentials.json");
249
250 Self::from_path(file_path)
251 }
252
253 pub fn from_path(path: impl Into<PathBuf>) -> Result<Self> {
255 Ok(Self {
256 file_path: path.into(),
257 })
258 }
259
260 async fn ensure_config_dir(&self) -> Result<()> {
262 if let Some(config_dir) = self.file_path.parent() {
263 tokio::fs::create_dir_all(config_dir).await.map_err(|e| {
264 Error::TokenProvider(format!("failed to create config directory: {}", e))
265 })?;
266
267 #[cfg(unix)]
268 {
269 use std::os::unix::fs::PermissionsExt;
270 let mut perms = tokio::fs::metadata(config_dir)
271 .await
272 .map_err(|e| {
273 Error::TokenProvider(format!("failed to get config dir metadata: {}", e))
274 })?
275 .permissions();
276 perms.set_mode(0o700);
277 tokio::fs::set_permissions(config_dir, perms)
278 .await
279 .map_err(|e| {
280 Error::TokenProvider(format!("failed to set config dir permissions: {}", e))
281 })?;
282 }
283 }
284 Ok(())
285 }
286
287 async fn load_file(&self) -> Result<CredentialsFile> {
289 self.ensure_config_dir().await?;
290
291 if !self.file_path.exists() {
292 return Ok(CredentialsFile::new());
293 }
294
295 let content = tokio::fs::read_to_string(&self.file_path)
296 .await
297 .map_err(|e| Error::TokenProvider(format!("failed to read credentials file: {}", e)))?;
298
299 let file: CredentialsFile = serde_json::from_str(&content).map_err(|e| {
300 Error::TokenProvider(format!("failed to parse credentials file: {}", e))
301 })?;
302
303 Ok(file)
304 }
305
306 async fn save_file(&self, file: &CredentialsFile) -> Result<()> {
308 self.ensure_config_dir().await?;
309
310 let content = serde_json::to_string_pretty(file)
311 .map_err(|e| Error::TokenProvider(format!("failed to serialize credentials: {}", e)))?;
312
313 let random_suffix = {
316 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
317 use rand::RngCore;
318 let mut rng = rand::rng();
319 let mut random_bytes = [0u8; 8];
320 rng.fill_bytes(&mut random_bytes);
321 URL_SAFE_NO_PAD.encode(random_bytes)
322 };
323
324 let temp_path = self.file_path.with_file_name(format!(
325 "{}.{}.tmp",
326 self.file_path
327 .file_name()
328 .and_then(|n| n.to_str())
329 .unwrap_or("credentials.json"),
330 random_suffix
331 ));
332 tokio::fs::write(&temp_path, content).await.map_err(|e| {
333 Error::TokenProvider(format!("failed to write credentials file: {}", e))
334 })?;
335
336 #[cfg(unix)]
337 {
338 use std::os::unix::fs::PermissionsExt;
339 if let Ok(metadata) = tokio::fs::metadata(&temp_path).await {
340 let mut perms = metadata.permissions();
341 perms.set_mode(0o600);
342 let _ = tokio::fs::set_permissions(&temp_path, perms).await;
343 }
344 }
345
346 tokio::fs::rename(&temp_path, &self.file_path)
347 .await
348 .map_err(|e| Error::TokenProvider(format!("failed to rename temp file: {}", e)))?;
349
350 Ok(())
351 }
352}
353
354#[async_trait]
355impl RefreshTokenStore for FileRefreshTokenStore {
356 async fn load(&self, key: &TokenStoreKey) -> Result<Option<SerializedTokens>> {
357 let file = self.load_file().await?;
358 Ok(file.entries.get(&key.to_string()).cloned())
359 }
360
361 async fn save(&self, key: &TokenStoreKey, tokens: &SerializedTokens) -> Result<()> {
362 let mut file = self.load_file().await?;
363 file.entries.insert(key.to_string(), tokens.clone());
364 self.save_file(&file).await
365 }
366
367 async fn delete(&self, key: &TokenStoreKey) -> Result<()> {
368 let mut file = self.load_file().await?;
369 file.entries.remove(&key.to_string());
370 self.save_file(&file).await
371 }
372}
373
/// Caller-supplied options for [`OAuthFlow::build_authorize_url`].
#[derive(Debug, Clone)]
pub struct AuthorizeParams {
    /// CSRF state to embed; a random one is generated when `None`.
    pub state: Option<String>,
    /// NOTE(review): not read by `build_authorize_url`, which always generates its
    /// own PKCE challenge.
    pub code_challenge: Option<String>,
    /// NOTE(review): not read either; the generated challenge is always SHA-256 (S256).
    pub code_challenge_method: Option<String>,
}
385
386#[derive(Debug, Clone)]
388pub struct AuthorizeContext {
389 pub url: String,
390 pub state: String,
391 pub code_verifier: String,
392 pub expires_at: OffsetDateTime,
393}
394
395pub struct OAuthFlow {
397 client: BasicClient<
398 EndpointSet,
399 oauth2::EndpointNotSet,
400 oauth2::EndpointNotSet,
401 oauth2::EndpointNotSet,
402 EndpointSet,
403 >,
404 config: OAuthConfig,
405 http: Arc<Client>,
406}
407
408impl OAuthFlow {
409 pub fn new(config: OAuthConfig, http: Arc<Client>) -> Result<Self> {
411 let client_id = ClientId::new(config.client_id.clone());
412 let auth_url = AuthUrl::new(config.auth_endpoint.clone())
413 .map_err(|e| Error::TokenProvider(format!("invalid auth_url: {}", e)))?;
414 let token_url = TokenUrl::new(config.token_endpoint.clone())
415 .map_err(|e| Error::TokenProvider(format!("invalid token_url: {}", e)))?;
416 let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
417 .map_err(|e| Error::TokenProvider(format!("invalid redirect_url: {}", e)))?;
418
419 let mut client_builder = BasicClient::new(client_id)
420 .set_auth_uri(auth_url)
421 .set_token_uri(token_url)
422 .set_redirect_uri(redirect_url);
423
424 if let Some(ref client_secret) = config.client_secret {
425 client_builder =
426 client_builder.set_client_secret(ClientSecret::new(client_secret.clone()));
427 }
428
429 Ok(Self {
430 client: client_builder,
431 config,
432 http,
433 })
434 }
435
436 pub fn build_authorize_url(&self, params: &AuthorizeParams) -> AuthorizeContext {
438 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
440
441 let csrf_token = if let Some(ref state) = params.state {
443 CsrfToken::new(state.clone())
444 } else {
445 CsrfToken::new_random()
446 };
447
448 let mut auth_request = self.client.authorize_url(|| csrf_token.clone());
450
451 for scope_str in &self.config.scopes {
453 auth_request = auth_request.add_scope(Scope::new(scope_str.clone()));
454 }
455
456 auth_request = auth_request.set_pkce_challenge(pkce_challenge);
458
459 for (key, value) in &self.config.additional_params {
461 auth_request = auth_request.add_extra_param(key, value);
462 }
463
464 let (auth_url, csrf_token_actual) = auth_request.url();
466
467 let mut url = url::Url::parse(auth_url.as_str()).expect("invalid auth_url");
469 url.query_pairs_mut()
470 .append_pair("access_type", "offline")
471 .append_pair("prompt", "consent");
472
473 if let Some(ref audience) = self.config.audience {
474 url.query_pairs_mut().append_pair("audience", audience);
475 }
476
477 let expires_at = OffsetDateTime::now_utc() + Duration::from_secs(600); AuthorizeContext {
480 url: url.to_string(),
481 state: csrf_token_actual.secret().to_string(),
482 code_verifier: pkce_verifier.secret().to_string(),
483 expires_at,
484 }
485 }
486
487 pub async fn exchange_code(
489 &self,
490 context: &AuthorizeContext,
491 code: &str,
492 ) -> Result<OAuthTokens> {
493 let code = AuthorizationCode::new(code.to_string());
494 let pkce_verifier = PkceCodeVerifier::new(context.code_verifier.clone());
495
496 let token_request = self
497 .client
498 .exchange_code(code)
499 .set_pkce_verifier(pkce_verifier);
500
501 let token_response = token_request
502 .request_async(self.http.as_ref())
503 .await
504 .map_err(|e| Error::TokenProvider(format!("oauth token exchange failed: {}", e)))?;
505
506 Ok(OAuthTokens::from_oauth2_response(
507 token_response,
508 OffsetDateTime::now_utc(),
509 ))
510 }
511
512 pub async fn refresh(&self, refresh_token: &str) -> Result<OAuthTokens> {
514 let refresh_token = RefreshToken::new(refresh_token.to_string());
515
516 let token_request = self.client.exchange_refresh_token(&refresh_token);
517
518 let token_response = token_request
519 .request_async(self.http.as_ref())
520 .await
521 .map_err(|e| Error::TokenProvider(format!("oauth token refresh failed: {}", e)))?;
522
523 Ok(OAuthTokens::from_oauth2_response(
524 token_response,
525 OffsetDateTime::now_utc(),
526 ))
527 }
528
529 pub async fn revoke(&self, _refresh_token: &str) -> Result<()> {
531 Ok(())
533 }
534}
535
536pub struct RefreshTokenProvider<S: RefreshTokenStore> {
542 flow: OAuthFlow,
543 store: Arc<S>,
544 cache: RwLock<Option<TokenCacheEntry>>,
545 store_key: TokenStoreKey,
546}
547
548impl<S: RefreshTokenStore> RefreshTokenProvider<S> {
549 pub fn new(flow: OAuthFlow, store: Arc<S>, store_key: TokenStoreKey) -> Self {
551 Self {
552 flow,
553 store,
554 cache: RwLock::new(None),
555 store_key,
556 }
557 }
558
559 async fn ensure_tokens(&self, force_refresh: bool) -> Result<OAuthTokens> {
561 let now = OffsetDateTime::now_utc();
562
563 if !force_refresh {
565 if let Some(ref entry) = *self.cache.read() {
566 if !entry.needs_refresh(now) {
567 return Ok(entry.tokens.clone());
568 }
569 }
570 }
571
572 let stored = self
574 .store
575 .load(&self.store_key)
576 .await?
577 .ok_or_else(|| Error::TokenProvider("refresh token unavailable".to_string()))?;
578
579 let tokens = self.flow.refresh(&stored.refresh_token).await?;
581 let refresh_token = tokens
582 .refresh_token
583 .clone()
584 .unwrap_or_else(|| stored.refresh_token.clone());
585 let scopes: Vec<String> = if let Some(scopes_str) = tokens.scope.as_ref() {
586 scopes_str.split_whitespace().map(String::from).collect()
587 } else if !stored.scopes.is_empty() {
588 stored.scopes.clone()
589 } else {
590 Vec::new()
591 };
592 let token_type = if !tokens.token_type.is_empty() {
593 tokens.token_type.clone()
594 } else {
595 stored.token_type.clone()
596 };
597
598 {
600 let mut cache = self.cache.write();
601 *cache = Some(TokenCacheEntry::new(tokens.clone()));
602 }
603
604 let serialized = SerializedTokens {
606 refresh_token,
607 scopes,
608 expires_at: Some(tokens.expires_at),
609 token_type,
610 updated_at: now,
611 };
612 self.store.save(&self.store_key, &serialized).await?;
613
614 Ok(tokens)
615 }
616}
617
618#[async_trait]
619impl<S: RefreshTokenStore> TokenProvider for RefreshTokenProvider<S> {
620 async fn access_token(&self) -> Result<String> {
621 let tokens = self.ensure_tokens(false).await?;
622 Ok(tokens.access_token)
623 }
624
625 async fn refresh_token(&self) -> Result<String> {
626 let tokens = self.ensure_tokens(true).await?;
628 Ok(tokens.access_token)
629 }
630
631 fn kind(&self) -> ProviderKind {
632 ProviderKind::UserOauth
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";

    /// Test config: Google's auth endpoint, the given token endpoint and scopes, no secret.
    fn test_config(token_endpoint: &str, scopes: &[&str]) -> OAuthConfig {
        OAuthConfig {
            auth_endpoint: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
            token_endpoint: token_endpoint.to_string(),
            client_id: "test-client-id".to_string(),
            client_secret: None,
            redirect_uri: "http://127.0.0.1:4317".to_string(),
            scopes: scopes.iter().map(|s| s.to_string()).collect(),
            audience: None,
            additional_params: HashMap::new(),
        }
    }

    /// Flow built from `test_config` with a fresh HTTP client.
    fn test_flow(token_endpoint: &str, scopes: &[&str]) -> OAuthFlow {
        OAuthFlow::new(test_config(token_endpoint, scopes), Arc::new(Client::new())).unwrap()
    }

    /// Authorization parameters with nothing caller-supplied.
    fn default_params() -> AuthorizeParams {
        AuthorizeParams {
            state: None,
            code_challenge: None,
            code_challenge_method: None,
        }
    }

    /// Bearer tokens without a refresh token or scope, expiring at `expires_at`.
    fn tokens_expiring_at(expires_at: OffsetDateTime) -> OAuthTokens {
        OAuthTokens {
            access_token: "test-token".to_string(),
            refresh_token: None,
            expires_at,
            scope: None,
            token_type: "Bearer".to_string(),
        }
    }

    /// Enterprise-profile key in the `global` location for `project`.
    fn enterprise_key(project: &str) -> TokenStoreKey {
        TokenStoreKey {
            profile: ApiProfile::Enterprise,
            project_number: Some(project.to_string()),
            endpoint_location: Some("global".to_string()),
            user_hint: None,
        }
    }

    /// Stored Bearer tokens with the single scope `scope1`.
    fn stored_tokens(
        refresh_token: &str,
        expires_at: OffsetDateTime,
        updated_at: OffsetDateTime,
    ) -> SerializedTokens {
        SerializedTokens {
            refresh_token: refresh_token.to_string(),
            scopes: vec!["scope1".to_string()],
            expires_at: Some(expires_at),
            token_type: "Bearer".to_string(),
            updated_at,
        }
    }

    /// Starts a mock server whose `POST /token` answers 200 with `body`.
    async fn token_server(body: serde_json::Value) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&server)
            .await;
        server
    }

    /// Seeds a store with an expired entry holding `original-refresh-token`.
    ///
    /// It then obtains an access token through `RefreshTokenProvider` against a
    /// token endpoint answering `response`, and returns the refresh token
    /// persisted afterwards.
    async fn refresh_token_after_refresh(response: serde_json::Value, project: &str) -> String {
        let server = token_server(response).await;
        let flow = test_flow(&format!("{}/token", server.uri()), &["scope1"]);
        let temp_dir = tempdir().unwrap();
        let store = Arc::new(
            FileRefreshTokenStore::from_path(temp_dir.path().join("credentials.json")).unwrap(),
        );
        let key = enterprise_key(project);

        let an_hour_ago = OffsetDateTime::now_utc() - Duration::from_secs(3600);
        store
            .save(
                &key,
                &stored_tokens("original-refresh-token", an_hour_ago, an_hour_ago),
            )
            .await
            .unwrap();

        let provider = RefreshTokenProvider::new(flow, Arc::clone(&store), key.clone());
        assert_eq!(provider.access_token().await.unwrap(), "new-access-token");

        let stored = store.load(&key).await.unwrap().unwrap();
        store.delete(&key).await.unwrap();
        stored.refresh_token
    }

    #[test]
    fn test_token_cache_entry_needs_refresh() {
        let now = OffsetDateTime::now_utc();
        let expires_at = now + Duration::from_secs(120);
        let entry = TokenCacheEntry::new(tokens_expiring_at(expires_at));

        // Two minutes left: outside the 60 s margin.
        assert!(!entry.needs_refresh(now));
        // Thirty seconds left: inside the margin.
        assert!(entry.needs_refresh(expires_at - Duration::from_secs(30)));
    }

    #[tokio::test]
    async fn test_token_store_key_display() {
        let display = enterprise_key("123456").to_string();
        assert!(display.contains("enterprise"));
        assert!(display.contains("project=123456"));
        assert!(display.contains("location=global"));
    }

    #[tokio::test]
    async fn test_oauth_tokens_from_oauth2_response() {
        let now = OffsetDateTime::now_utc();
        let response: StandardTokenResponse<
            oauth2::EmptyExtraTokenFields,
            oauth2::basic::BasicTokenType,
        > = serde_json::from_value(serde_json::json!({
            "access_token": "access-token-123",
            "refresh_token": "refresh-token-456",
            "expires_in": 3600,
            "token_type": "Bearer",
            "scope": "scope1 scope2"
        }))
        .unwrap();

        let tokens = OAuthTokens::from_oauth2_response(response, now);
        assert_eq!(tokens.access_token, "access-token-123");
        assert_eq!(tokens.refresh_token, Some("refresh-token-456".to_string()));
        assert_eq!(tokens.scope, Some("scope1 scope2".to_string()));
        assert_eq!(tokens.token_type, "Bearer");
        assert!(tokens.expires_at > now);
    }

    #[tokio::test]
    async fn test_oauth_flow_refresh_token() {
        let server = token_server(serde_json::json!({
            "access_token": "new-access-token",
            "expires_in": 3600,
            "token_type": "Bearer"
        }))
        .await;
        let flow = test_flow(&format!("{}/token", server.uri()), &["scope1"]);

        let tokens = flow.refresh("refresh-token-123").await.unwrap();
        assert_eq!(tokens.access_token, "new-access-token");
    }

    #[test]
    fn test_pkce_code_verifier_and_challenge_generation() {
        let is_url_safe = |c: char| c.is_ascii_alphanumeric() || c == '-' || c == '_';
        let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();

        let verifier_str = verifier.secret();
        assert!(!verifier_str.is_empty());
        assert!(verifier_str.len() >= 43 && verifier_str.len() <= 128);
        assert!(verifier_str.chars().all(is_url_safe));

        let challenge_str = challenge.as_str();
        assert!(!challenge_str.is_empty());
        // Unpadded base64url of a SHA-256 digest is always 43 characters.
        assert_eq!(challenge_str.len(), 43);
        assert!(challenge_str.chars().all(is_url_safe));
    }

    #[test]
    fn test_pkce_generates_unique_values() {
        let (challenge1, verifier1) = PkceCodeChallenge::new_random_sha256();
        let (challenge2, verifier2) = PkceCodeChallenge::new_random_sha256();

        assert_ne!(verifier1.secret(), verifier2.secret());
        assert_ne!(challenge1.as_str(), challenge2.as_str());
    }

    #[test]
    fn test_state_generation_uniqueness() {
        let state1 = CsrfToken::new_random();
        let state2 = CsrfToken::new_random();

        assert_ne!(state1.secret(), state2.secret());
        assert!(!state1.secret().is_empty());
        assert!(!state2.secret().is_empty());
    }

    #[test]
    fn test_token_needs_refresh_at_exact_expiry() {
        let expires_at = OffsetDateTime::now_utc() + Duration::from_secs(60);
        let entry = TokenCacheEntry::new(tokens_expiring_at(expires_at));

        assert!(entry.needs_refresh(expires_at));
    }

    #[test]
    fn test_token_needs_refresh_just_before_margin() {
        let expires_at = OffsetDateTime::now_utc() + Duration::from_secs(120);
        let entry = TokenCacheEntry::new(tokens_expiring_at(expires_at));

        // One second before the 60 s margin starts.
        assert!(!entry.needs_refresh(expires_at - Duration::from_secs(61)));
        // Exactly at the margin boundary, which counts as stale.
        assert!(entry.needs_refresh(expires_at - Duration::from_secs(60)));
    }

    #[test]
    fn test_token_needs_refresh_after_expiry() {
        let now = OffsetDateTime::now_utc();
        let entry = TokenCacheEntry::new(tokens_expiring_at(now - Duration::from_secs(10)));

        assert!(entry.needs_refresh(now));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_file_store_concurrent_saves() {
        use tokio::task::JoinSet;

        let temp_dir = tempdir().unwrap();
        let store = Arc::new(
            FileRefreshTokenStore::from_path(temp_dir.path().join("credentials.json")).unwrap(),
        );
        let key = enterprise_key("concurrent-test");

        let mut join_set = JoinSet::new();
        for i in 0..10 {
            let store = Arc::clone(&store);
            let key = key.clone();
            join_set.spawn(async move {
                let now = OffsetDateTime::now_utc();
                let tokens = stored_tokens(
                    &format!("token-{}", i),
                    now + Duration::from_secs(3600),
                    now,
                );
                store.save(&key, &tokens).await
            });
        }

        while let Some(result) = join_set.join_next().await {
            result.unwrap().unwrap();
        }

        // Whichever save landed last wins; all of them share one key.
        let loaded = store
            .load(&key)
            .await
            .unwrap()
            .expect("an entry saved by one of the tasks");
        assert!(loaded.refresh_token.starts_with("token-"));

        store.delete(&key).await.unwrap();
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_file_store_atomic_write() {
        let temp_dir = tempdir().unwrap();
        let store =
            FileRefreshTokenStore::from_path(temp_dir.path().join("credentials.json")).unwrap();
        let key = enterprise_key("atomic-test");
        let now = OffsetDateTime::now_utc();

        store
            .save(
                &key,
                &stored_tokens("initial-token", now + Duration::from_secs(3600), now),
            )
            .await
            .unwrap();

        // No temp files may be left next to the credentials file.
        let entries: Vec<_> = std::fs::read_dir(temp_dir.path())
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0], std::ffi::OsStr::new("credentials.json"));

        store.delete(&key).await.unwrap();
    }

    #[cfg(unix)]
    #[tokio::test]
    #[serial_test::serial]
    async fn test_file_store_uses_private_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = tempdir().unwrap();
        let config_dir = temp_dir.path().join("nblm");
        let store_path = config_dir.join("credentials.json");
        let store = FileRefreshTokenStore::from_path(&store_path).unwrap();
        let key = enterprise_key("permissions-test");
        let now = OffsetDateTime::now_utc();

        store
            .save(
                &key,
                &stored_tokens("secret-token", now + Duration::from_secs(3600), now),
            )
            .await
            .unwrap();

        let file_mode = std::fs::metadata(&store_path).unwrap().permissions().mode() & 0o777;
        assert_eq!(file_mode, 0o600);
        let dir_mode = std::fs::metadata(&config_dir).unwrap().permissions().mode() & 0o777;
        assert_eq!(dir_mode, 0o700);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_refresh_token_preserved_when_omitted_in_response() {
        let persisted = refresh_token_after_refresh(
            serde_json::json!({
                "access_token": "new-access-token",
                "expires_in": 3600,
                "token_type": "Bearer"
            }),
            "omission-test",
        )
        .await;

        assert_eq!(persisted, "original-refresh-token");
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn test_refresh_token_updated_when_included_in_response() {
        let persisted = refresh_token_after_refresh(
            serde_json::json!({
                "access_token": "new-access-token",
                "refresh_token": "new-refresh-token",
                "expires_in": 3600,
                "token_type": "Bearer"
            }),
            "update-test",
        )
        .await;

        assert_eq!(persisted, "new-refresh-token");
    }

    #[tokio::test]
    async fn test_state_mismatch_detection() {
        let flow = test_flow(GOOGLE_TOKEN_ENDPOINT, &["scope1"]);

        let first = flow.build_authorize_url(&default_params());
        let second = flow.build_authorize_url(&default_params());

        // Generated states are random, so an attacker-chosen value cannot match.
        assert!(!first.state.is_empty());
        assert_ne!(first.state, "attacker-controlled-state");
        assert_ne!(first.state, second.state);

        // The state returned to the caller is the one actually embedded in the URL.
        let url = url::Url::parse(&first.url).unwrap();
        let state_in_url = url
            .query_pairs()
            .find(|(key, _)| key == "state")
            .map(|(_, value)| value.into_owned());
        assert_eq!(state_in_url.as_deref(), Some(first.state.as_str()));
    }

    #[test]
    fn test_authorize_url_contains_required_parameters() {
        let flow = test_flow(GOOGLE_TOKEN_ENDPOINT, &["scope1", "scope2"]);
        let context = flow.build_authorize_url(&default_params());

        let url = url::Url::parse(&context.url).unwrap();
        let params: HashMap<_, _> = url.query_pairs().collect();

        for required in ["client_id", "redirect_uri", "scope", "state", "code_challenge"] {
            assert!(params.contains_key(required), "missing `{}`", required);
        }
        for (key, expected) in [
            ("response_type", "code"),
            ("code_challenge_method", "S256"),
            ("access_type", "offline"),
            ("prompt", "consent"),
        ] {
            assert!(params.contains_key(key), "missing `{}`", key);
            assert_eq!(params.get(key).unwrap(), expected, "unexpected `{}`", key);
        }
    }

    #[test]
    fn test_custom_state_is_preserved() {
        let flow = test_flow(GOOGLE_TOKEN_ENDPOINT, &["scope1"]);

        let custom_state = "my-custom-state-value";
        let context = flow.build_authorize_url(&AuthorizeParams {
            state: Some(custom_state.to_string()),
            ..default_params()
        });

        assert_eq!(context.state, custom_state);
        assert!(context.url.contains(&format!("state={}", custom_state)));
    }
}