Skip to main content

stack_auth/
auto_refresh.rs

1use tokio::sync::Mutex;
2
3use crate::refresher::Refresher;
4use crate::{ServiceToken, Token};
5
6/// Internal errors from [`AutoRefresh::get_token`].
7///
8/// Strategy wrappers convert these into [`AuthError`](crate::AuthError) for the
9/// public API.
10#[derive(Debug, thiserror::Error)]
11pub(crate) enum AutoRefreshError {
12    /// No token is cached and the strategy cannot self-authenticate.
13    #[error("No token found")]
14    NotFound,
15    /// The token has expired and refresh failed or is unavailable.
16    #[error("Token has expired")]
17    Expired,
18    /// The refresh/auth HTTP call failed.
19    #[error("Auth error: {0}")]
20    Auth(#[from] crate::AuthError),
21}
22
23impl From<AutoRefreshError> for crate::AuthError {
24    fn from(err: AutoRefreshError) -> Self {
25        match err {
26            AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
27            AutoRefreshError::Expired => crate::AuthError::TokenExpired,
28            AutoRefreshError::Auth(e) => e,
29        }
30    }
31}
32
33/// Caches a token in memory and uses a [`Refresher`] to re-authenticate
34/// or refresh before expiry.
35///
36/// # Concurrency model
37///
38/// Internal state is protected by a [`tokio::sync::Mutex`]. The key design
39/// decision is *when* the lock is held during a refresh, which depends on
40/// whether the current token is still usable as a bearer credential:
41///
42/// - [`Token::is_expired()`] — returns `true` when the token is within **90
43///   seconds** of its `expires_at` timestamp. This triggers a preemptive
44///   refresh attempt.
45/// - [`Token::is_usable()`] — returns `true` when the token has **not yet
46///   reached** its `expires_at` timestamp. A token can be "expired" (in the
47///   leeway sense) but still "usable" (the server will still accept it).
48///
49/// This distinction enables two concurrent refresh strategies:
50///
51/// 1. **Expiring but still usable** — The refreshing caller drops the lock
52///    before making the HTTP request. Concurrent callers acquire the lock and
53///    receive the current (still-valid) token immediately.
54/// 2. **Fully expired** — The refreshing caller holds the lock through the
55///    HTTP request. Concurrent callers block on `lock().await` until the
56///    refresh completes, then see the new token.
57///
58/// Cascade prevention: the `refresh_in_progress` flag prevents multiple
59/// callers from initiating concurrent refreshes.
60///
61/// # Flow diagram
62///
63/// ```mermaid
64/// flowchart TD
65///     Start["get_token()"] --> Lock["Acquire lock"]
66///     Lock --> Cached{Token cached?}
67///     Cached -- No --> TryCred0["try_credential(None)"]
68///     TryCred0 -- None --> ErrNotFound["Return NotFound"]
69///     TryCred0 -- "Some(cred)" --> InitAuth["refresh(cred)
70///     (lock HELD)"]
71///     InitAuth -- OK --> SaveInit["save + cache token"]
72///     SaveInit --> ReturnNew["Return Ok(new token)"]
73///     InitAuth -- Err --> ErrAuth["Return Auth(err)"]
74///     Cached -- Yes --> CheckRefresh{is_expired?}
75///
76///     CheckRefresh -- "No (fresh)" --> CloneFresh["Clone access token,
77///     release lock"]
78///     CloneFresh --> ReturnOk["Return Ok(token)"]
79///
80///     CheckRefresh -- "Yes (needs refresh)" --> InProgress{refresh_in_progress?}
81///     InProgress -- Yes --> Usable0{is_usable?}
82///     Usable0 -- Yes --> CloneUsable0["Clone access token"]
83///     CloneUsable0 --> ReturnOk
84///     Usable0 -- No --> ErrExpired["Return Expired"]
85///
86///     InProgress -- No --> TryCred{try_credential}
87///     TryCred -- None --> Usable1{is_usable?}
88///     Usable1 -- Yes --> CloneUsable1["Clone access token"]
89///     CloneUsable1 --> ReturnOk
90///     Usable1 -- No --> ErrExpired
91///
92///     TryCred -- "Some(cred)" --> SetFlag["refresh_in_progress = true"]
93///     SetFlag --> Usable2{is_usable?}
94///
95///     Usable2 -- "Yes (expiring but usable)" --> DropLock["Clone access token,
96///     release lock"]
97///     DropLock --> HTTP1["refresh(cred)
98///     (lock NOT held)"]
99///     HTTP1 -- OK --> Relock1["Re-acquire lock,
100///     save + cache, clear flag"]
101///     HTTP1 -- Err --> Restore1["Restore credential,
102///     clear flag"]
103///     Relock1 --> ReturnOld["Return Ok(old token)"]
104///     Restore1 --> ReturnOld
105///
106///     Usable2 -- "No (fully expired)" --> HTTP2["refresh(cred)
107///     (lock HELD)"]
108///     HTTP2 -- OK --> StoreNew["save + cache,
109///     clear flag, release lock"]
110///     StoreNew --> ReturnNew2["Return Ok(new token)"]
111///     HTTP2 -- Err --> Restore2["Restore credential,
112///     clear flag"]
113///     Restore2 --> ErrExpired
114/// ```
115#[cfg_attr(doc, aquamarine::aquamarine)]
116pub(crate) struct AutoRefresh<R> {
117    refresher: R,
118    state: Mutex<State>,
119}
120
121struct State {
122    token: Option<Token>,
123    refresh_in_progress: bool,
124}
125
126impl<R> AutoRefresh<R> {
127    /// Create a new `AutoRefresh` with no initial token.
128    ///
129    /// The first call to `get_token` will attempt initial authentication via
130    /// `try_credential(None)` → `refresh()`. Use this for refreshers that can
131    /// self-authenticate (e.g. access keys).
132    pub(crate) fn new(refresher: R) -> Self {
133        Self {
134            refresher,
135            state: Mutex::new(State {
136                token: None,
137                refresh_in_progress: false,
138            }),
139        }
140    }
141
142    /// Create a new `AutoRefresh` with a pre-loaded token.
143    ///
144    /// Use this for refreshers that cannot self-authenticate (e.g. OAuth,
145    /// which needs a refresh token from a prior device code flow).
146    pub(crate) fn with_token(refresher: R, token: Token) -> Self {
147        Self {
148            refresher,
149            state: Mutex::new(State {
150                token: Some(token),
151                refresh_in_progress: false,
152            }),
153        }
154    }
155}
156
157impl<R: Refresher> AutoRefresh<R> {
158    /// Retrieve a valid access token, refreshing or re-authenticating as needed.
159    pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
160        let mut state = self.state.lock().await;
161
162        // No cached token — attempt initial auth.
163        if state.token.is_none() {
164            let Some(credential) = self.refresher.try_credential(None) else {
165                return Err(AutoRefreshError::NotFound);
166            };
167            state.refresh_in_progress = true;
168            match self.refresher.refresh(&credential).await {
169                Ok(new_token) => {
170                    self.refresher.save(&new_token);
171                    let service_token = ServiceToken::new(new_token.access_token().clone());
172                    state.token = Some(new_token);
173                    state.refresh_in_progress = false;
174                    return Ok(service_token);
175                }
176                Err(err) => {
177                    state.refresh_in_progress = false;
178                    return Err(AutoRefreshError::Auth(err));
179                }
180            }
181        }
182
183        let needs_refresh = state.token.as_ref().is_some_and(|t| t.is_expired());
184        if !needs_refresh {
185            // Token is fresh — clone and return.
186            let token = state.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
187            return Ok(ServiceToken::new(token.access_token().clone()));
188        }
189
190        // Check cascade prevention flag.
191        if state.refresh_in_progress {
192            let token = state.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
193            if token.is_usable() {
194                return Ok(ServiceToken::new(token.access_token().clone()));
195            }
196            // NOTE: If a refresh was started while the token was still usable
197            // (lock released) but the token has since crossed its real expiry,
198            // we return Expired rather than waiting for the in-flight refresh.
199            // This is a deliberate trade-off: adding a Notify/condvar to wait
200            // for the in-flight refresh would increase complexity, and the
201            // window is narrow (token must expire during the HTTP call). The
202            // 90s leeway on is_expired() makes this unlikely. Callers can
203            // retry and will get the new token once the refresh completes.
204            return Err(AutoRefreshError::Expired);
205        }
206
207        // Token needs refresh. Try to get a credential.
208        let credential = self.refresher.try_credential(state.token.as_mut());
209
210        let Some(credential) = credential else {
211            // No credential available (e.g. OAuth with no refresh token).
212            let token = state.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
213            if token.is_usable() {
214                return Ok(ServiceToken::new(token.access_token().clone()));
215            }
216            return Err(AutoRefreshError::Expired);
217        };
218
219        state.refresh_in_progress = true;
220
221        // Check if the current token is still usable.
222        let is_usable = state.token.as_ref().is_some_and(|t| t.is_usable());
223
224        if is_usable {
225            // Token is expiring but still usable. Clone the current access
226            // token, drop the lock, and refresh in the background of this call.
227            let current_service_token = ServiceToken::new(
228                state
229                    .token
230                    .as_ref()
231                    .ok_or(AutoRefreshError::NotFound)?
232                    .access_token()
233                    .clone(),
234            );
235            drop(state);
236
237            match self.refresher.refresh(&credential).await {
238                Ok(new_token) => {
239                    self.refresher.save(&new_token);
240                    let mut state = self.state.lock().await;
241                    state.token = Some(new_token);
242                    state.refresh_in_progress = false;
243                }
244                Err(err) => {
245                    tracing::warn!(%err, "token refresh failed (token still usable)");
246                    let mut state = self.state.lock().await;
247                    if let Some(token) = state.token.as_mut() {
248                        self.refresher.restore(token, credential);
249                    }
250                    state.refresh_in_progress = false;
251                }
252            }
253
254            Ok(current_service_token)
255        } else {
256            // Token is fully expired. Refresh while holding the lock.
257            match self.refresher.refresh(&credential).await {
258                Ok(new_token) => {
259                    self.refresher.save(&new_token);
260                    let service_token = ServiceToken::new(new_token.access_token().clone());
261                    state.token = Some(new_token);
262                    state.refresh_in_progress = false;
263                    Ok(service_token)
264                }
265                Err(err) => {
266                    tracing::warn!(%err, "token refresh failed");
267                    if let Some(token) = state.token.as_mut() {
268                        self.refresher.restore(token, credential);
269                    }
270                    state.refresh_in_progress = false;
271                    Err(AutoRefreshError::Expired)
272                }
273            }
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth_refresher::OAuthRefresher;
    use crate::SecretToken;
    use mocktail::prelude::*;
    use stack_profile::ProfileStore;
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Path of the token endpoint that every mock in this module serves.
    const TOKEN_PATH: &str = "/oauth/token";

    /// Build a token expiring `expires_in` seconds from now, optionally
    /// carrying a refresh token.
    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
        let now_secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let refresh_token = refresh.then(|| SecretToken::new("test-refresh-token"));

        Token {
            access_token: SecretToken::new(access),
            token_type: String::from("Bearer"),
            expires_at: now_secs + expires_in,
            refresh_token,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// JSON body of a successful token response carrying `access`.
    fn refresh_response_json(access: &str) -> serde_json::Value {
        serde_json::json!({
            "access_token": access,
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    /// JSON body of an OAuth-style error response.
    fn error_json(error: &str) -> serde_json::Value {
        serde_json::json!({
            "error": error,
            "error_description": format!("{error} occurred")
        })
    }

    /// Mock set whose token endpoint answers with a refreshed token.
    fn refresh_succeeds(access: &'static str) -> MockSet {
        let mut set = MockSet::new();
        set.mock(move |when, then| {
            when.post().path(TOKEN_PATH);
            then.json(refresh_response_json(access));
        });
        set
    }

    /// Mock set whose token endpoint answers `400 Bad Request`.
    fn refresh_rejected(error: &'static str) -> MockSet {
        let mut set = MockSet::new();
        set.mock(move |when, then| {
            when.post().path(TOKEN_PATH);
            then.bad_request().json(error_json(error));
        });
        set
    }

    /// Start a mock HTTP server serving `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// Persist `token` under `dir` and wrap it in an OAuth-backed
    /// `AutoRefresh` that talks to `server`.
    fn auto_refresh_with_token(
        dir: &tempfile::TempDir,
        server: &MockServer,
        token: Token,
    ) -> AutoRefresh<OAuthRefresher> {
        let store = ProfileStore::new(dir.path());
        store.save_profile(&token).unwrap();
        AutoRefresh::with_token(
            OAuthRefresher::new(
                Some(store),
                server.url(""),
                "cli",
                "ap-southeast-2.aws",
                None,
            ),
            token,
        )
    }

    /// Run `get_token` on its own task, panicking inside the task on error.
    fn spawn_get_token(
        strategy: &Arc<AutoRefresh<OAuthRefresher>>,
    ) -> tokio::task::JoinHandle<ServiceToken> {
        let shared = Arc::clone(strategy);
        tokio::spawn(async move { shared.get_token().await.unwrap() })
    }

    // ---- Basic loading tests ----

    #[tokio::test]
    async fn test_returns_cached_token() {
        let dir = tempfile::tempdir().unwrap();
        let server = start_server(MockSet::new()).await;
        let cached = make_token("my-access-token", 3600, false);
        let strategy = auto_refresh_with_token(&dir, &server, cached);

        let got = strategy.get_token().await.unwrap();

        assert_eq!(got.as_str(), "my-access-token");
    }

    #[tokio::test]
    async fn test_returns_not_found_when_no_token_and_oauth() {
        let server = start_server(MockSet::new()).await;
        let refresher = OAuthRefresher::new(
            Some(ProfileStore::new("/tmp/nonexistent")),
            server.url(""),
            "cli",
            "ap-southeast-2.aws",
            None,
        );
        let strategy = AutoRefresh::new(refresher);

        let outcome = strategy.get_token().await;

        assert!(matches!(outcome.unwrap_err(), AutoRefreshError::NotFound));
    }

    #[tokio::test]
    async fn test_caches_token_across_calls() {
        let dir = tempfile::tempdir().unwrap();
        let server = start_server(MockSet::new()).await;
        let cached = make_token("my-access-token", 3600, false);
        let strategy = auto_refresh_with_token(&dir, &server, cached);

        let first = strategy.get_token().await.unwrap();
        assert_eq!(first.as_str(), "my-access-token");

        // With the on-disk profile gone, only the in-memory cache can answer.
        std::fs::remove_file(dir.path().join("auth.json")).unwrap();

        let second = strategy.get_token().await.unwrap();
        assert_eq!(second.as_str(), "my-access-token");
    }

    // ---- Expiry tests ----

    #[tokio::test]
    async fn test_expired_token_without_refresh_token_returns_expired() {
        let dir = tempfile::tempdir().unwrap();
        let server = start_server(MockSet::new()).await;
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));

        let outcome = strategy.get_token().await;

        assert!(matches!(outcome.unwrap_err(), AutoRefreshError::Expired));
    }

    // ---- Refresh tests ----

    #[tokio::test]
    async fn test_refreshes_expiring_token() {
        let server = start_server(refresh_succeeds("refreshed-token")).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let got = strategy.get_token().await.unwrap();

        assert_eq!(got.as_str(), "refreshed-token");
    }

    #[tokio::test]
    async fn test_refresh_persists_new_token_to_disk() {
        let server = start_server(refresh_succeeds("refreshed-token")).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let _ = strategy.get_token().await.unwrap();

        // Reload from disk through a fresh store to prove the save happened.
        let on_disk: Token = ProfileStore::new(dir.path()).load_profile().unwrap();
        assert_eq!(on_disk.access_token().as_str(), "refreshed-token");
    }

    #[tokio::test]
    async fn test_refresh_failure_returns_expired_when_token_is_expired() {
        let server = start_server(refresh_rejected("invalid_grant")).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        let outcome = strategy.get_token().await;

        assert!(matches!(outcome.unwrap_err(), AutoRefreshError::Expired));
    }

    #[tokio::test]
    async fn test_does_not_refresh_fresh_token() {
        // The endpoint answers 500, so any refresh attempt would not yield
        // the expected token below.
        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path(TOKEN_PATH);
            then.internal_server_error()
                .json(error_json("should_not_be_called"));
        });
        let server = start_server(mocks).await;
        let dir = tempfile::tempdir().unwrap();
        let fresh = make_token("fresh-token", 3600, true);
        let strategy = auto_refresh_with_token(&dir, &server, fresh);

        let got = strategy.get_token().await.unwrap();

        assert_eq!(got.as_str(), "fresh-token");
    }

    // ---- Cascade prevention tests ----

    #[tokio::test]
    async fn test_refresh_token_is_taken_preventing_second_refresh() {
        let server = start_server(refresh_succeeds("refreshed-token")).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        // The first call has to refresh because the token is already expired.
        let first = strategy.get_token().await.unwrap();
        assert_eq!(first.as_str(), "refreshed-token");

        // From now on the endpoint only answers with an error.
        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path(TOKEN_PATH);
            then.bad_request().json(error_json("should_not_be_called"));
        });

        // The refreshed token has a fresh expiry, so the second call is
        // served from the cache without touching the server.
        let second = strategy.get_token().await.unwrap();
        assert_eq!(second.as_str(), "refreshed-token");
    }

    #[tokio::test]
    async fn test_failed_refresh_restores_refresh_token_for_retry() {
        let server = start_server(refresh_rejected("invalid_grant")).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        // The refresh is rejected, so the first call reports Expired.
        let outcome = strategy.get_token().await;
        assert!(matches!(outcome.unwrap_err(), AutoRefreshError::Expired));

        // The refresh token must be back in place for a retry to be possible.
        {
            let state = strategy.state.lock().await;
            assert!(state.token.is_some());
            assert!(state.token.as_ref().unwrap().refresh_token().is_some());
        }

        // Make the endpoint succeed from here on.
        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path(TOKEN_PATH);
            then.json(refresh_response_json("refreshed-token"));
        });

        // With the refresh token restored, the retry goes through.
        let got = strategy.get_token().await.unwrap();
        assert_eq!(got.as_str(), "refreshed-token");
    }

    #[tokio::test]
    async fn test_access_token_remains_after_refresh_token_is_taken() {
        let server = start_server(refresh_rejected("server_error")).await;
        let dir = tempfile::tempdir().unwrap();
        // 30s to expiry is inside the 90s leeway (is_expired() = true) while
        // the access token itself is still usable.
        let expiring = make_token("still-usable", 30, true);
        let strategy = auto_refresh_with_token(&dir, &server, expiring);

        // The refresh fails, yet the caller still receives the usable token.
        let got = strategy.get_token().await.unwrap();
        assert_eq!(got.as_str(), "still-usable");

        // Both the access token and the refresh token must still be cached.
        let state = strategy.state.lock().await;
        assert!(state.token.is_some());
        let cached = state.token.as_ref().unwrap();
        assert_eq!(cached.access_token().as_str(), "still-usable");
        assert!(
            cached.refresh_token().is_some(),
            "refresh token should be restored after failed refresh"
        );
    }

    #[tokio::test]
    async fn test_failed_refresh_of_usable_token_can_be_retried() {
        let server = start_server(refresh_rejected("server_error")).await;
        let dir = tempfile::tempdir().unwrap();
        // 30s to expiry: is_expired() = true, is_usable() = true.
        let expiring = make_token("still-usable", 30, true);
        let strategy = auto_refresh_with_token(&dir, &server, expiring);

        // The failed refresh is absorbed and the usable token is returned.
        let first = strategy.get_token().await.unwrap();
        assert_eq!(first.as_str(), "still-usable");

        // Make the endpoint succeed from here on.
        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path(TOKEN_PATH);
            then.json(refresh_response_json("refreshed-token"));
        });

        // The retry may hand back either the old or the new token.
        let second = strategy.get_token().await.unwrap();
        let second_str = second.as_str();
        assert!(
            second_str == "still-usable" || second_str == "refreshed-token",
            "expected old or refreshed token, got: {}",
            second.as_str()
        );

        // Whatever was returned, the cache must now hold the refreshed token.
        let state = strategy.state.lock().await;
        let cached_access = state.token.as_ref().unwrap().access_token().as_str();
        assert_eq!(cached_access, "refreshed-token");
    }

    #[tokio::test]
    async fn test_multiple_sequential_calls_only_refresh_once() {
        let server = start_server(refresh_succeeds("refreshed-once")).await;
        let dir = tempfile::tempdir().unwrap();
        let strategy = auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));

        // Only this first call should reach the server.
        let first = strategy.get_token().await.unwrap();
        assert_eq!(first.as_str(), "refreshed-once");

        // Any further refresh would now be observable as a different token.
        server.mocks().clear();
        server.mocks().mock(|when, then| {
            when.post().path(TOKEN_PATH);
            then.json(refresh_response_json("refreshed-twice"));
        });

        // Four more calls, all expected to be served from the cache.
        for _call in 2..=5 {
            let got = strategy.get_token().await.unwrap();
            assert_eq!(
                got.as_str(),
                "refreshed-once",
                "should return cached refreshed token, not trigger another refresh"
            );
        }
    }

    // ---- Concurrent access tests ----

    #[tokio::test]
    async fn test_concurrent_access_with_expiring_but_usable_token() {
        let server = start_server(refresh_succeeds("refreshed-token")).await;
        let dir = tempfile::tempdir().unwrap();
        let expiring = make_token("still-usable", 30, true);
        let strategy = Arc::new(auto_refresh_with_token(&dir, &server, expiring));

        let handle_a = spawn_get_token(&strategy);
        let handle_b = spawn_get_token(&strategy);

        let (joined_a, joined_b) = tokio::join!(handle_a, handle_b);
        let token_a = joined_a.unwrap();
        let token_b = joined_b.unwrap();

        // Either caller may see the old or the refreshed token.
        assert!(
            matches!(token_a.as_str(), "still-usable" | "refreshed-token"),
            "unexpected token_a: {}",
            token_a.as_str()
        );
        assert!(
            matches!(token_b.as_str(), "still-usable" | "refreshed-token"),
            "unexpected token_b: {}",
            token_b.as_str()
        );
    }

    #[tokio::test]
    async fn test_concurrent_access_with_fully_expired_token() {
        let server = start_server(refresh_succeeds("refreshed-token")).await;
        let dir = tempfile::tempdir().unwrap();
        let expired = make_token("expired-token", 0, true);
        let strategy = Arc::new(auto_refresh_with_token(&dir, &server, expired));

        let handle_a = spawn_get_token(&strategy);
        let handle_b = spawn_get_token(&strategy);

        let (joined_a, joined_b) = tokio::join!(handle_a, handle_b);

        // Both callers must end up with the refreshed token.
        assert_eq!(joined_a.unwrap().as_str(), "refreshed-token");
        assert_eq!(joined_b.unwrap().as_str(), "refreshed-token");
    }
}
716
717#[cfg(test)]
718mod stress_tests {
719    use super::*;
720    use crate::oauth_refresher::OAuthRefresher;
721    use crate::SecretToken;
722    use stack_profile::ProfileStore;
723    use std::sync::atomic::{AtomicUsize, Ordering};
724    use std::sync::Arc;
725    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
726
727    /// Tracks in-flight and peak concurrency for test assertions.
728    #[derive(Clone)]
729    struct CountingState {
730        total: Arc<AtomicUsize>,
731        current: Arc<AtomicUsize>,
732        peak: Arc<AtomicUsize>,
733    }
734
735    impl CountingState {
736        fn new() -> Self {
737            Self {
738                total: Arc::new(AtomicUsize::new(0)),
739                current: Arc::new(AtomicUsize::new(0)),
740                peak: Arc::new(AtomicUsize::new(0)),
741            }
742        }
743
744        fn enter(&self) {
745            self.total.fetch_add(1, Ordering::SeqCst);
746            let prev = self.current.fetch_add(1, Ordering::SeqCst);
747            self.peak.fetch_max(prev + 1, Ordering::SeqCst);
748        }
749
750        fn exit(&self) {
751            self.current.fetch_sub(1, Ordering::SeqCst);
752        }
753
754        fn peak(&self) -> usize {
755            self.peak.load(Ordering::SeqCst)
756        }
757
758        fn total(&self) -> usize {
759            self.total.load(Ordering::SeqCst)
760        }
761    }
762
763    #[derive(Clone)]
764    struct DelayedRefreshState {
765        counting: CountingState,
766        delay: Duration,
767    }
768
769    async fn delayed_refresh_handler(
770        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
771    ) -> axum::Json<serde_json::Value> {
772        state.counting.enter();
773        tokio::time::sleep(state.delay).await;
774        state.counting.exit();
775        axum::Json(serde_json::json!({
776            "access_token": "refreshed-token",
777            "token_type": "Bearer",
778            "expires_in": 3600,
779            "refresh_token": "new-refresh-token"
780        }))
781    }
782
783    async fn delayed_error_handler(
784        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
785    ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
786        state.counting.enter();
787        tokio::time::sleep(state.delay).await;
788        state.counting.exit();
789        (
790            axum::http::StatusCode::BAD_REQUEST,
791            axum::Json(serde_json::json!({
792                "error": "invalid_grant",
793                "error_description": "invalid_grant occurred"
794            })),
795        )
796    }
797
798    async fn start_axum_server<H, T>(
799        handler: H,
800        state: DelayedRefreshState,
801    ) -> (url::Url, CountingState)
802    where
803        H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
804        T: 'static,
805    {
806        let counting = state.counting.clone();
807        let app = axum::Router::new()
808            .route("/oauth/token", axum::routing::post(handler))
809            .with_state(state);
810        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
811        let addr = listener.local_addr().unwrap();
812        tokio::spawn(async move {
813            axum::serve(listener, app).await.unwrap();
814        });
815        let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
816        (base_url, counting)
817    }
818
819    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
820        let now = SystemTime::now()
821            .duration_since(UNIX_EPOCH)
822            .unwrap()
823            .as_secs();
824
825        Token {
826            access_token: SecretToken::new(access),
827            token_type: "Bearer".to_string(),
828            expires_at: now + expires_in,
829            refresh_token: if refresh {
830                Some(SecretToken::new("test-refresh-token"))
831            } else {
832                None
833            },
834            region: None,
835            client_id: None,
836            device_instance_id: None,
837        }
838    }
839
840    fn auto_refresh_with_token(
841        dir: &tempfile::TempDir,
842        base_url: &url::Url,
843        token: Token,
844    ) -> AutoRefresh<OAuthRefresher> {
845        let store = ProfileStore::new(dir.path());
846        store.save_profile(&token).unwrap();
847        let refresher = OAuthRefresher::new(
848            Some(store),
849            base_url.clone(),
850            "cli",
851            "ap-southeast-2.aws",
852            None,
853        );
854        AutoRefresh::with_token(refresher, token)
855    }
856
857    const CONCURRENCY: usize = 50; // number of concurrent `get_token` tasks each test below spawns
858
859    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
860    async fn test_concurrent_fresh_token_no_contention() {
861        let counting = CountingState::new();
862        let state = DelayedRefreshState {
863            counting: counting.clone(),
864            delay: Duration::from_millis(500),
865        };
866        let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
867        let dir = tempfile::tempdir().unwrap();
868        let strategy = Arc::new(auto_refresh_with_token(
869            &dir,
870            &base_url,
871            make_token("fresh-token", 3600, true),
872        ));
873
874        let start = Instant::now();
875        let mut handles = Vec::with_capacity(CONCURRENCY);
876        for _ in 0..CONCURRENCY {
877            let s = Arc::clone(&strategy);
878            handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
879        }
880
881        let results: Vec<_> = {
882            let mut results = Vec::with_capacity(handles.len());
883            for handle in handles {
884                results.push(handle.await.unwrap());
885            }
886            results
887        };
888        let elapsed = start.elapsed();
889
890        for token in &results {
891            assert_eq!(token.as_str(), "fresh-token");
892        }
893
894        assert!(
895            elapsed < Duration::from_millis(200),
896            "expected < 200ms for fresh tokens, got {:?}",
897            elapsed
898        );
899        assert_eq!(stats.total(), 0, "no refresh requests should be made");
900    }
901
902    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
903    async fn test_concurrent_expiring_token_non_blocking_reads() {
904        let counting = CountingState::new();
905        let state = DelayedRefreshState {
906            counting: counting.clone(),
907            delay: Duration::from_millis(500),
908        };
909        let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
910        let dir = tempfile::tempdir().unwrap();
911        let strategy = Arc::new(auto_refresh_with_token(
912            &dir,
913            &base_url,
914            make_token("still-usable", 30, true),
915        ));
916
917        let start = Instant::now();
918        let mut handles = Vec::with_capacity(CONCURRENCY);
919        for _ in 0..CONCURRENCY {
920            let s = Arc::clone(&strategy);
921            handles.push(tokio::spawn(async move {
922                let call_start = Instant::now();
923                let token = s.get_token().await.unwrap();
924                (token, call_start.elapsed())
925            }));
926        }
927
928        let results: Vec<_> = {
929            let mut results = Vec::with_capacity(handles.len());
930            for handle in handles {
931                results.push(handle.await.unwrap());
932            }
933            results
934        };
935        let elapsed = start.elapsed();
936
937        for (token, _) in &results {
938            assert!(
939                token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
940                "unexpected token: {}",
941                token.as_str()
942            );
943        }
944
945        let fast_callers = results
946            .iter()
947            .filter(|(_, dur)| *dur < Duration::from_millis(100))
948            .count();
949        assert!(
950            fast_callers >= CONCURRENCY - 1,
951            "expected at least {} fast callers, got {} (total elapsed: {:?})",
952            CONCURRENCY - 1,
953            fast_callers,
954            elapsed
955        );
956
957        assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
958        assert_eq!(stats.total(), 1, "total refresh requests");
959    }
960
961    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
962    async fn test_concurrent_expired_token_blocks_until_refresh() {
963        let refresh_delay = Duration::from_millis(200);
964        let counting = CountingState::new();
965        let state = DelayedRefreshState {
966            counting: counting.clone(),
967            delay: refresh_delay,
968        };
969        let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
970        let dir = tempfile::tempdir().unwrap();
971        let strategy = Arc::new(auto_refresh_with_token(
972            &dir,
973            &base_url,
974            make_token("expired-token", 0, true),
975        ));
976
977        let start = Instant::now();
978        let mut handles = Vec::with_capacity(CONCURRENCY);
979        for _ in 0..CONCURRENCY {
980            let s = Arc::clone(&strategy);
981            handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
982        }
983
984        let results: Vec<_> = {
985            let mut results = Vec::with_capacity(handles.len());
986            for handle in handles {
987                results.push(handle.await.unwrap());
988            }
989            results
990        };
991        let elapsed = start.elapsed();
992
993        for token in &results {
994            assert_eq!(token.as_str(), "refreshed-token");
995        }
996
997        assert!(
998            elapsed < refresh_delay + Duration::from_millis(200),
999            "expected < {:?} for blocked callers, got {:?}",
1000            refresh_delay + Duration::from_millis(200),
1001            elapsed
1002        );
1003
1004        assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1005        assert_eq!(stats.total(), 1, "total refresh requests");
1006    }
1007
1008    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1009    async fn test_concurrent_expired_token_refresh_failure_recovers() {
1010        let counting = CountingState::new();
1011        let state = DelayedRefreshState {
1012            counting: counting.clone(),
1013            delay: Duration::from_millis(10),
1014        };
1015        let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1016        let dir = tempfile::tempdir().unwrap();
1017        let strategy = Arc::new(auto_refresh_with_token(
1018            &dir,
1019            &base_url,
1020            make_token("expired-token", 0, true),
1021        ));
1022
1023        let mut handles = Vec::with_capacity(CONCURRENCY);
1024        for _ in 0..CONCURRENCY {
1025            let s = Arc::clone(&strategy);
1026            handles.push(tokio::spawn(async move { s.get_token().await }));
1027        }
1028
1029        let results: Vec<_> = {
1030            let mut results = Vec::with_capacity(handles.len());
1031            for handle in handles {
1032                results.push(handle.await.unwrap());
1033            }
1034            results
1035        };
1036
1037        for result in &results {
1038            assert!(result.is_err(), "expected Expired error, got Ok");
1039            assert!(matches!(
1040                result.as_ref().unwrap_err(),
1041                AutoRefreshError::Expired
1042            ));
1043        }
1044
1045        let state = strategy.state.lock().await;
1046        assert!(
1047            state.token.as_ref().unwrap().refresh_token().is_some(),
1048            "refresh token should be restored after failed refresh"
1049        );
1050        drop(state);
1051
1052        assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1053        assert!(
1054            stats.total() >= 1,
1055            "at least one refresh attempt should be made"
1056        );
1057    }
1058
1059    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1060    async fn test_concurrent_refresh_failure_then_retry() {
1061        // Phase 1: Server returns errors.
1062        let counting1 = CountingState::new();
1063        let state1 = DelayedRefreshState {
1064            counting: counting1.clone(),
1065            delay: Duration::from_millis(50),
1066        };
1067        let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1068        let dir = tempfile::tempdir().unwrap();
1069        let strategy = Arc::new(auto_refresh_with_token(
1070            &dir,
1071            &base_url,
1072            make_token("expired-token", 0, true),
1073        ));
1074
1075        let mut handles = Vec::with_capacity(CONCURRENCY);
1076        for _ in 0..CONCURRENCY {
1077            let s = Arc::clone(&strategy);
1078            handles.push(tokio::spawn(async move { s.get_token().await }));
1079        }
1080
1081        let results: Vec<_> = {
1082            let mut results = Vec::with_capacity(handles.len());
1083            for handle in handles {
1084                results.push(handle.await.unwrap());
1085            }
1086            results
1087        };
1088
1089        for result in &results {
1090            assert!(
1091                result.is_err(),
1092                "first wave: expected Expired, got Ok({})",
1093                result.as_ref().unwrap().as_str()
1094            );
1095        }
1096
1097        // Phase 2: New server that returns success.
1098        let counting2 = CountingState::new();
1099        let state2 = DelayedRefreshState {
1100            counting: counting2.clone(),
1101            delay: Duration::from_millis(50),
1102        };
1103        let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1104
1105        let strategy2 = Arc::new(auto_refresh_with_token(
1106            &dir,
1107            &base_url2,
1108            make_token("expired-token", 0, true),
1109        ));
1110
1111        let mut handles = Vec::with_capacity(CONCURRENCY);
1112        for _ in 0..CONCURRENCY {
1113            let s = Arc::clone(&strategy2);
1114            handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1115        }
1116
1117        let results: Vec<_> = {
1118            let mut results = Vec::with_capacity(handles.len());
1119            for handle in handles {
1120                results.push(handle.await.unwrap());
1121            }
1122            results
1123        };
1124
1125        for token in &results {
1126            assert_eq!(token.as_str(), "refreshed-token");
1127        }
1128
1129        assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1130    }
1131}