// stack_auth/auto_refresh.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2
3use tokio::sync::{Mutex, MutexGuard, Notify};
4
5use crate::refresher::Refresher;
6use crate::{ServiceToken, Token};
7
8/// Internal errors from [`AutoRefresh::get_token`].
9///
10/// Strategy wrappers convert these into [`AuthError`](crate::AuthError) for the
11/// public API.
12#[derive(Debug, thiserror::Error)]
13pub(crate) enum AutoRefreshError {
14    /// No token is cached and the strategy cannot self-authenticate.
15    #[error("No token found")]
16    NotFound,
17    /// The token has expired and refresh failed or is unavailable.
18    #[error("Token has expired")]
19    Expired,
20    /// The refresh/auth HTTP call failed.
21    #[error("Auth error: {0}")]
22    Auth(#[from] crate::AuthError),
23}
24
25impl From<AutoRefreshError> for crate::AuthError {
26    fn from(err: AutoRefreshError) -> Self {
27        match err {
28            AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
29            AutoRefreshError::Expired => crate::AuthError::TokenExpired,
30            AutoRefreshError::Auth(e) => e,
31        }
32    }
33}
34
/// Caches a token in memory and uses a [`Refresher`] to re-authenticate
/// or refresh before expiry.
///
/// See the [crate-level documentation](crate#token-refresh) for a full
/// description of the concurrency model and flow diagram.
pub(crate) struct AutoRefresh<R> {
    /// Strategy object that hands out credentials and performs the actual
    /// refresh call (`try_credential`, `refresh`, `save`, `restore`).
    refresher: R,
    /// The cached token, guarded by an async mutex so it can be held across
    /// the refresh call when the token is fully expired.
    state: Mutex<State>,
    /// Set to `true` while a refresh HTTP call is in-flight.
    ///
    /// Stored as an [`AtomicBool`] rather than inside [`State`] so that
    /// [`CancelGuard`] can reset it on future cancellation without acquiring
    /// the mutex.
    refresh_in_progress: AtomicBool,
    /// Wakes callers that are parked waiting for an in-flight refresh to
    /// finish (or to be cancelled).
    refresh_notify: Notify,
}
51
/// Mutable data protected by the `state` mutex of `AutoRefresh`.
struct State {
    /// The currently cached token; `None` until one is supplied at
    /// construction or obtained by initial authentication.
    token: Option<Token>,
}
55
56/// Ensures [`AutoRefresh::refresh_in_progress`] is cleared and waiters are
57/// notified if the refresh future is cancelled (dropped) before completing.
58///
59/// On the normal path (success or handled error), the guard is defused before
60/// drop so that the regular cleanup code runs instead.
61struct CancelGuard<'a> {
62    in_progress: &'a AtomicBool,
63    notify: &'a Notify,
64    defused: bool,
65}
66
67impl Drop for CancelGuard<'_> {
68    fn drop(&mut self) {
69        if !self.defused {
70            self.in_progress.store(false, Ordering::Release);
71            self.notify.notify_waiters();
72        }
73    }
74}
75
76impl CancelGuard<'_> {
77    fn defuse(&mut self) {
78        self.defused = true;
79    }
80}
81
82impl State {
83    fn service_token(&self) -> Result<ServiceToken, AutoRefreshError> {
84        let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
85        Ok(ServiceToken::new(token.access_token().clone()))
86    }
87
88    fn require_usable_token(&self) -> Result<ServiceToken, AutoRefreshError> {
89        let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
90        if token.is_usable() {
91            Ok(ServiceToken::new(token.access_token().clone()))
92        } else {
93            Err(AutoRefreshError::Expired)
94        }
95    }
96}
97
98impl<R> AutoRefresh<R> {
99    /// Create a new `AutoRefresh` with no initial token.
100    ///
101    /// The first call to `get_token` will attempt initial authentication via
102    /// `try_credential(None)` → `refresh()`. Use this for refreshers that can
103    /// self-authenticate (e.g. access keys).
104    pub(crate) fn new(refresher: R) -> Self {
105        Self {
106            refresher,
107            state: Mutex::new(State { token: None }),
108            refresh_in_progress: AtomicBool::new(false),
109            refresh_notify: Notify::new(),
110        }
111    }
112
113    /// Create a new `AutoRefresh` with a pre-loaded token.
114    ///
115    /// Use this for refreshers that cannot self-authenticate (e.g. OAuth,
116    /// which needs a refresh token from a prior device code flow).
117    pub(crate) fn with_token(refresher: R, token: Token) -> Self {
118        Self {
119            refresher,
120            state: Mutex::new(State { token: Some(token) }),
121            refresh_in_progress: AtomicBool::new(false),
122            refresh_notify: Notify::new(),
123        }
124    }
125}
126
127impl<R: Refresher> AutoRefresh<R> {
128    /// Retrieve a valid access token, refreshing or re-authenticating as needed.
129    pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
130        let mut state = self.state.lock().await;
131
132        if state.token.is_none() {
133            return self.initial_auth(&mut state).await;
134        }
135
136        if !state.token.as_ref().is_some_and(|t| t.is_expired()) {
137            return state.service_token();
138        }
139
140        if self.refresh_in_progress.load(Ordering::Acquire) {
141            return self.wait_for_in_flight_refresh(state).await;
142        }
143
144        let Some(credential) = self.refresher.try_credential(state.token.as_mut()) else {
145            return state.require_usable_token();
146        };
147
148        self.refresh_in_progress.store(true, Ordering::Release);
149
150        if state.token.as_ref().is_some_and(|t| t.is_usable()) {
151            self.refresh_non_blocking(state, credential).await
152        } else {
153            self.refresh_blocking(&mut state, credential).await
154        }
155    }
156
157    /// No cached token — authenticate via `try_credential(None)`.
158    ///
159    /// The lock is held throughout to prevent concurrent initial-auth attempts.
160    async fn initial_auth(&self, state: &mut State) -> Result<ServiceToken, AutoRefreshError> {
161        let Some(credential) = self.refresher.try_credential(None) else {
162            return Err(AutoRefreshError::NotFound);
163        };
164        self.refresh_in_progress.store(true, Ordering::Release);
165        let mut guard = CancelGuard {
166            in_progress: &self.refresh_in_progress,
167            notify: &self.refresh_notify,
168            defused: false,
169        };
170        match self.refresher.refresh(&credential).await {
171            Ok(new_token) => {
172                guard.defuse();
173                self.refresher.save(&new_token);
174                let service_token = ServiceToken::new(new_token.access_token().clone());
175                state.token = Some(new_token);
176                self.refresh_in_progress.store(false, Ordering::Release);
177                Ok(service_token)
178            }
179            Err(err) => {
180                guard.defuse();
181                self.refresh_in_progress.store(false, Ordering::Release);
182                Err(AutoRefreshError::Auth(err))
183            }
184        }
185    }
186
187    /// Another caller is already refreshing — return the current token if still
188    /// usable, otherwise wait for the in-flight refresh to complete via `Notify`.
189    ///
190    /// Takes `MutexGuard` by value because the lock is dropped before awaiting
191    /// the notification.
192    async fn wait_for_in_flight_refresh(
193        &self,
194        state: MutexGuard<'_, State>,
195    ) -> Result<ServiceToken, AutoRefreshError> {
196        if let Ok(token) = state.service_token() {
197            if state.token.as_ref().is_some_and(|t| t.is_usable()) {
198                return Ok(token);
199            }
200        }
201        // Token crossed real expiry during in-flight refresh. Wait for the
202        // refresh to complete rather than returning Expired.
203        let notified = self.refresh_notify.notified();
204        drop(state);
205        notified.await;
206        // Re-check after wake — refresh may have failed.
207        let state = self.state.lock().await;
208        state.require_usable_token()
209    }
210
211    /// Token is expiring but still usable — drop the lock, refresh in the
212    /// background of this call, and return the old (still-valid) token.
213    ///
214    /// Takes `MutexGuard` by value because the lock is dropped before the HTTP
215    /// request. Notifies waiters after the refresh completes (success or error).
216    ///
217    /// A [`CancelGuard`] ensures that if this future is cancelled during the
218    /// HTTP request, `refresh_in_progress` is cleared, the credential is
219    /// restored (best-effort via `try_lock`), and waiters are notified.
220    async fn refresh_non_blocking(
221        &self,
222        state: MutexGuard<'_, State>,
223        credential: R::Credential,
224    ) -> Result<ServiceToken, AutoRefreshError> {
225        let current_service_token = state.service_token()?;
226        drop(state);
227
228        let mut guard = CancelGuard {
229            in_progress: &self.refresh_in_progress,
230            notify: &self.refresh_notify,
231            defused: false,
232        };
233
234        match self.refresher.refresh(&credential).await {
235            Ok(new_token) => {
236                guard.defuse();
237                self.refresher.save(&new_token);
238                let mut state = self.state.lock().await;
239                state.token = Some(new_token);
240                self.refresh_in_progress.store(false, Ordering::Release);
241            }
242            Err(err) => {
243                guard.defuse();
244                tracing::warn!(%err, "token refresh failed (token still usable)");
245                let mut state = self.state.lock().await;
246                if let Some(token) = state.token.as_mut() {
247                    self.refresher.restore(token, credential);
248                }
249                self.refresh_in_progress.store(false, Ordering::Release);
250            }
251        }
252
253        self.refresh_notify.notify_waiters();
254        Ok(current_service_token)
255    }
256
257    /// Token is fully expired — refresh while holding the lock so concurrent
258    /// callers block on `lock().await` until the new token is available.
259    ///
260    /// A [`CancelGuard`] ensures that if this future is cancelled during the
261    /// HTTP request, `refresh_in_progress` is cleared and waiters are notified
262    /// so they don't hang indefinitely. (The credential is lost on cancel —
263    /// see [`CancelGuard`] docs — but subsequent callers will get `Expired`
264    /// rather than blocking forever.)
265    async fn refresh_blocking(
266        &self,
267        state: &mut State,
268        credential: R::Credential,
269    ) -> Result<ServiceToken, AutoRefreshError> {
270        let mut guard = CancelGuard {
271            in_progress: &self.refresh_in_progress,
272            notify: &self.refresh_notify,
273            defused: false,
274        };
275        match self.refresher.refresh(&credential).await {
276            Ok(new_token) => {
277                guard.defuse();
278                self.refresher.save(&new_token);
279                let service_token = ServiceToken::new(new_token.access_token().clone());
280                state.token = Some(new_token);
281                self.refresh_in_progress.store(false, Ordering::Release);
282                Ok(service_token)
283            }
284            Err(err) => {
285                guard.defuse();
286                tracing::warn!(%err, "token refresh failed");
287                if let Some(token) = state.token.as_mut() {
288                    self.refresher.restore(token, credential);
289                }
290                self.refresh_in_progress.store(false, Ordering::Release);
291                Err(AutoRefreshError::Expired)
292            }
293        }
294    }
295}
296
297#[cfg(test)]
298#[allow(clippy::unwrap_used)]
299mod tests {
300    use super::*;
301    use crate::oauth_refresher::OAuthRefresher;
302    use crate::SecretToken;
303    use mocktail::prelude::*;
304    use stack_profile::ProfileStore;
305    use std::sync::Arc;
306    use std::time::{SystemTime, UNIX_EPOCH};
307
308    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
309        let now = SystemTime::now()
310            .duration_since(UNIX_EPOCH)
311            .unwrap()
312            .as_secs();
313
314        Token {
315            access_token: SecretToken::new(access),
316            token_type: "Bearer".to_string(),
317            expires_at: now + expires_in,
318            refresh_token: if refresh {
319                Some(SecretToken::new("test-refresh-token"))
320            } else {
321                None
322            },
323            region: None,
324            client_id: None,
325            device_instance_id: None,
326        }
327    }
328
329    fn refresh_response_json(access: &str) -> serde_json::Value {
330        serde_json::json!({
331            "access_token": access,
332            "token_type": "Bearer",
333            "expires_in": 3600,
334            "refresh_token": "new-refresh-token"
335        })
336    }
337
338    fn error_json(error: &str) -> serde_json::Value {
339        serde_json::json!({
340            "error": error,
341            "error_description": format!("{error} occurred")
342        })
343    }
344
345    async fn start_server(mocks: MockSet) -> MockServer {
346        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
347        server.start().await.unwrap();
348        server
349    }
350
351    fn auto_refresh_with_token(
352        dir: &tempfile::TempDir,
353        server: &MockServer,
354        token: Token,
355    ) -> AutoRefresh<OAuthRefresher> {
356        let store = ProfileStore::new(dir.path());
357        store.save_profile(&token).unwrap();
358        let refresher = OAuthRefresher::new(
359            Some(store),
360            server.url(""),
361            "cli",
362            "ap-southeast-2.aws",
363            None,
364        );
365        AutoRefresh::with_token(refresher, token)
366    }
367
368    mod given_no_cached_token {
369        use super::*;
370
371        #[tokio::test]
372        async fn returns_not_found_for_oauth() {
373            let server = start_server(MockSet::new()).await;
374            let store = ProfileStore::new("/tmp/nonexistent");
375            let refresher = OAuthRefresher::new(
376                Some(store),
377                server.url(""),
378                "cli",
379                "ap-southeast-2.aws",
380                None,
381            );
382            let strategy = AutoRefresh::new(refresher);
383
384            let err = strategy.get_token().await.unwrap_err();
385
386            assert!(
387                matches!(err, AutoRefreshError::NotFound),
388                "expected NotFound, got: {err:?}"
389            );
390        }
391    }
392
393    mod given_fresh_token {
394        use super::*;
395
396        #[tokio::test]
397        async fn returns_cached_token() {
398            let dir = tempfile::tempdir().unwrap();
399            let server = start_server(MockSet::new()).await;
400            let strategy =
401                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
402
403            let token = strategy.get_token().await.unwrap();
404
405            assert_eq!(
406                token.as_str(),
407                "my-access-token",
408                "should return the cached access token"
409            );
410        }
411
412        #[tokio::test]
413        async fn caches_across_calls() {
414            let dir = tempfile::tempdir().unwrap();
415            let server = start_server(MockSet::new()).await;
416            let strategy =
417                auto_refresh_with_token(&dir, &server, make_token("my-access-token", 3600, false));
418
419            let token1 = strategy.get_token().await.unwrap();
420            assert_eq!(
421                token1.as_str(),
422                "my-access-token",
423                "first call should return the cached token"
424            );
425
426            // Delete the file — second call should still return the cached token.
427            std::fs::remove_file(dir.path().join("auth.json")).unwrap();
428
429            let token2 = strategy.get_token().await.unwrap();
430            assert_eq!(
431                token2.as_str(),
432                "my-access-token",
433                "second call should return the cached token even after file deletion"
434            );
435        }
436
437        #[tokio::test]
438        async fn does_not_trigger_refresh() {
439            // Mock that would fail if hit — proves no refresh request is made.
440            let mut mocks = MockSet::new();
441            mocks.mock(|when, then| {
442                when.post().path("/oauth/token");
443                then.internal_server_error()
444                    .json(error_json("should_not_be_called"));
445            });
446            let server = start_server(mocks).await;
447            let dir = tempfile::tempdir().unwrap();
448            let strategy =
449                auto_refresh_with_token(&dir, &server, make_token("fresh-token", 3600, true));
450
451            let token = strategy.get_token().await.unwrap();
452
453            assert_eq!(
454                token.as_str(),
455                "fresh-token",
456                "should return fresh token without triggering refresh"
457            );
458        }
459    }
460
461    mod given_fully_expired_token {
462        use super::*;
463
464        mod without_refresh_token {
465            use super::*;
466
467            #[tokio::test]
468            async fn returns_expired() {
469                let dir = tempfile::tempdir().unwrap();
470                let server = start_server(MockSet::new()).await;
471                let strategy =
472                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, false));
473
474                let err = strategy.get_token().await.unwrap_err();
475
476                assert!(
477                    matches!(err, AutoRefreshError::Expired),
478                    "expected Expired, got: {err:?}"
479                );
480            }
481        }
482
483        mod with_refresh_token {
484            use super::*;
485
486            #[tokio::test]
487            async fn refreshes_and_returns_new_token() {
488                let mut mocks = MockSet::new();
489                mocks.mock(|when, then| {
490                    when.post().path("/oauth/token");
491                    then.json(refresh_response_json("refreshed-token"));
492                });
493                let server = start_server(mocks).await;
494                let dir = tempfile::tempdir().unwrap();
495                let strategy =
496                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
497
498                let token = strategy.get_token().await.unwrap();
499
500                assert_eq!(
501                    token.as_str(),
502                    "refreshed-token",
503                    "should return the refreshed token"
504                );
505            }
506
507            #[tokio::test]
508            async fn persists_refreshed_token_to_disk() {
509                let mut mocks = MockSet::new();
510                mocks.mock(|when, then| {
511                    when.post().path("/oauth/token");
512                    then.json(refresh_response_json("refreshed-token"));
513                });
514                let server = start_server(mocks).await;
515                let dir = tempfile::tempdir().unwrap();
516                let strategy =
517                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
518
519                let _ = strategy.get_token().await.unwrap();
520
521                // Verify the refreshed token was saved to disk.
522                let store = ProfileStore::new(dir.path());
523                let on_disk: Token = store.load_profile().unwrap();
524                assert_eq!(
525                    on_disk.access_token().as_str(),
526                    "refreshed-token",
527                    "refreshed token should be persisted to disk"
528                );
529            }
530
531            #[tokio::test]
532            async fn returns_expired_on_refresh_failure() {
533                let mut mocks = MockSet::new();
534                mocks.mock(|when, then| {
535                    when.post().path("/oauth/token");
536                    then.bad_request().json(error_json("invalid_grant"));
537                });
538                let server = start_server(mocks).await;
539                let dir = tempfile::tempdir().unwrap();
540                let strategy =
541                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
542
543                let err = strategy.get_token().await.unwrap_err();
544
545                assert!(
546                    matches!(err, AutoRefreshError::Expired),
547                    "expected Expired after failed refresh, got: {err:?}"
548                );
549            }
550
551            #[tokio::test]
552            async fn restores_refresh_token_after_failure() {
553                let mut mocks = MockSet::new();
554                mocks.mock(|when, then| {
555                    when.post().path("/oauth/token");
556                    then.bad_request().json(error_json("invalid_grant"));
557                });
558                let server = start_server(mocks).await;
559                let dir = tempfile::tempdir().unwrap();
560                let strategy =
561                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
562
563                // First call: refresh fails, returns Expired.
564                let err = strategy.get_token().await.unwrap_err();
565                assert!(
566                    matches!(err, AutoRefreshError::Expired),
567                    "expected Expired on first attempt, got: {err:?}"
568                );
569
570                // Verify the refresh token was restored so a retry is possible.
571                let state = strategy.state.lock().await;
572                assert!(
573                    state.token.is_some(),
574                    "token should still be cached after failed refresh"
575                );
576                assert!(
577                    state.token.as_ref().unwrap().refresh_token().is_some(),
578                    "refresh token should be restored for retry"
579                );
580                drop(state);
581
582                // Replace mock with a success response.
583                server.mocks().clear();
584                server.mocks().mock(|when, then| {
585                    when.post().path("/oauth/token");
586                    then.json(refresh_response_json("refreshed-token"));
587                });
588
589                // Second call: refresh token is available → retry succeeds.
590                let token = strategy.get_token().await.unwrap();
591                assert_eq!(
592                    token.as_str(),
593                    "refreshed-token",
594                    "retry should succeed with restored refresh token"
595                );
596            }
597
598            #[tokio::test]
599            async fn sequential_calls_only_refresh_once() {
600                let mut mocks = MockSet::new();
601                mocks.mock(|when, then| {
602                    when.post().path("/oauth/token");
603                    then.json(refresh_response_json("refreshed-once"));
604                });
605                let server = start_server(mocks).await;
606                let dir = tempfile::tempdir().unwrap();
607                let strategy =
608                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
609
610                // First call triggers refresh.
611                let token = strategy.get_token().await.unwrap();
612                assert_eq!(
613                    token.as_str(),
614                    "refreshed-once",
615                    "first call should trigger refresh"
616                );
617
618                // Swap mock to track if another refresh is attempted.
619                server.mocks().clear();
620                server.mocks().mock(|when, then| {
621                    when.post().path("/oauth/token");
622                    then.json(refresh_response_json("refreshed-twice"));
623                });
624
625                // Calls 2-5: the refreshed token is fresh, so no further refresh.
626                for _ in 0..4 {
627                    let token = strategy.get_token().await.unwrap();
628                    assert_eq!(
629                        token.as_str(),
630                        "refreshed-once",
631                        "should return cached refreshed token, not trigger another refresh"
632                    );
633                }
634            }
635
636            #[tokio::test]
637            async fn prevents_second_refresh_after_success() {
638                let mut mocks = MockSet::new();
639                mocks.mock(|when, then| {
640                    when.post().path("/oauth/token");
641                    then.json(refresh_response_json("refreshed-token"));
642                });
643                let server = start_server(mocks).await;
644                let dir = tempfile::tempdir().unwrap();
645                let strategy =
646                    auto_refresh_with_token(&dir, &server, make_token("old-token", 0, true));
647
648                // First call refreshes successfully.
649                let token = strategy.get_token().await.unwrap();
650                assert_eq!(
651                    token.as_str(),
652                    "refreshed-token",
653                    "first call should refresh the token"
654                );
655
656                // Replace the mock with one that errors.
657                server.mocks().clear();
658                server.mocks().mock(|when, then| {
659                    when.post().path("/oauth/token");
660                    then.bad_request().json(error_json("should_not_be_called"));
661                });
662
663                // Second call should return the refreshed token without hitting
664                // the server again (the new token has a fresh expiry).
665                let token = strategy.get_token().await.unwrap();
666                assert_eq!(
667                    token.as_str(),
668                    "refreshed-token",
669                    "second call should return cached refreshed token"
670                );
671            }
672        }
673    }
674
675    mod given_expiring_but_usable_token {
676        use super::*;
677
678        mod when_refresh_fails {
679            use super::*;
680
681            #[tokio::test]
682            async fn returns_current_token() {
683                let mut mocks = MockSet::new();
684                mocks.mock(|when, then| {
685                    when.post().path("/oauth/token");
686                    then.bad_request().json(error_json("server_error"));
687                });
688                let server = start_server(mocks).await;
689                let dir = tempfile::tempdir().unwrap();
690                // Token expires in 30s (within the 90s leeway so is_expired() = true),
691                // but the access token is still technically usable.
692                let strategy =
693                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));
694
695                // The refresh fails, but the access token should still be returned
696                // because it's still usable (30s remaining > 0).
697                let token = strategy.get_token().await.unwrap();
698                assert_eq!(
699                    token.as_str(),
700                    "still-usable",
701                    "should return still-usable token despite failed refresh"
702                );
703
704                // Verify the access token and refresh token are still present.
705                let state = strategy.state.lock().await;
706                assert!(state.token.is_some(), "token should still be cached");
707                assert_eq!(
708                    state.token.as_ref().unwrap().access_token().as_str(),
709                    "still-usable",
710                    "access token should be unchanged after failed refresh"
711                );
712                assert!(
713                    state.token.as_ref().unwrap().refresh_token().is_some(),
714                    "refresh token should be restored after failed refresh"
715                );
716            }
717
718            #[tokio::test]
719            async fn restores_refresh_token_for_retry() {
720                let mut mocks = MockSet::new();
721                mocks.mock(|when, then| {
722                    when.post().path("/oauth/token");
723                    then.bad_request().json(error_json("server_error"));
724                });
725                let server = start_server(mocks).await;
726                let dir = tempfile::tempdir().unwrap();
727                // Token expires in 30s — is_expired() = true, is_usable() = true.
728                let strategy =
729                    auto_refresh_with_token(&dir, &server, make_token("still-usable", 30, true));
730
731                // First call: refresh fails, but the still-usable token is returned.
732                let token = strategy.get_token().await.unwrap();
733                assert_eq!(
734                    token.as_str(),
735                    "still-usable",
736                    "first call should return still-usable token"
737                );
738
739                // Replace mock with a success response.
740                server.mocks().clear();
741                server.mocks().mock(|when, then| {
742                    when.post().path("/oauth/token");
743                    then.json(refresh_response_json("refreshed-token"));
744                });
745
746                // Second call: refresh token was restored, so the retry succeeds.
747                let token = strategy.get_token().await.unwrap();
748                assert!(
749                    token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
750                    "expected old or refreshed token, got: {}",
751                    token.as_str()
752                );
753
754                // Verify the cache now holds the refreshed token.
755                let state = strategy.state.lock().await;
756                assert_eq!(
757                    state.token.as_ref().unwrap().access_token().as_str(),
758                    "refreshed-token",
759                    "cache should hold the refreshed token after retry"
760                );
761            }
762        }
763    }
764
765    mod given_concurrent_callers {
766        use super::*;
767
768        #[tokio::test]
769        async fn returns_usable_token_while_refreshing() {
770            let mut mocks = MockSet::new();
771            mocks.mock(|when, then| {
772                when.post().path("/oauth/token");
773                then.json(refresh_response_json("refreshed-token"));
774            });
775            let server = start_server(mocks).await;
776            let dir = tempfile::tempdir().unwrap();
777            let strategy = Arc::new(auto_refresh_with_token(
778                &dir,
779                &server,
780                make_token("still-usable", 30, true),
781            ));
782
783            let s1 = Arc::clone(&strategy);
784            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
785
786            let s2 = Arc::clone(&strategy);
787            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
788
789            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
790            let token_a = result_a.unwrap();
791            let token_b = result_b.unwrap();
792
793            assert!(
794                token_a.as_str() == "still-usable" || token_a.as_str() == "refreshed-token",
795                "unexpected token_a: {}",
796                token_a.as_str()
797            );
798            assert!(
799                token_b.as_str() == "still-usable" || token_b.as_str() == "refreshed-token",
800                "unexpected token_b: {}",
801                token_b.as_str()
802            );
803        }
804
805        #[tokio::test]
806        async fn blocks_until_refresh_completes() {
807            let mut mocks = MockSet::new();
808            mocks.mock(|when, then| {
809                when.post().path("/oauth/token");
810                then.json(refresh_response_json("refreshed-token"));
811            });
812            let server = start_server(mocks).await;
813            let dir = tempfile::tempdir().unwrap();
814            let strategy = Arc::new(auto_refresh_with_token(
815                &dir,
816                &server,
817                make_token("expired-token", 0, true),
818            ));
819
820            let s1 = Arc::clone(&strategy);
821            let handle_a = tokio::spawn(async move { s1.get_token().await.unwrap() });
822
823            let s2 = Arc::clone(&strategy);
824            let handle_b = tokio::spawn(async move { s2.get_token().await.unwrap() });
825
826            let (result_a, result_b) = tokio::join!(handle_a, handle_b);
827            let token_a = result_a.unwrap();
828            let token_b = result_b.unwrap();
829
830            assert_eq!(
831                token_a.as_str(),
832                "refreshed-token",
833                "caller a should receive refreshed token"
834            );
835            assert_eq!(
836                token_b.as_str(),
837                "refreshed-token",
838                "caller b should receive refreshed token"
839            );
840        }
841    }
842}
843
844#[cfg(test)]
845#[allow(clippy::unwrap_used)]
846mod stress_tests {
847    use super::*;
848    use crate::oauth_refresher::OAuthRefresher;
849    use crate::SecretToken;
850    use stack_profile::ProfileStore;
851    use std::sync::atomic::{AtomicUsize, Ordering};
852    use std::sync::Arc;
853    use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
854
855    /// Tracks in-flight and peak concurrency for test assertions.
856    #[derive(Clone)]
857    struct CountingState {
858        total: Arc<AtomicUsize>,
859        current: Arc<AtomicUsize>,
860        peak: Arc<AtomicUsize>,
861    }
862
863    impl CountingState {
864        fn new() -> Self {
865            Self {
866                total: Arc::new(AtomicUsize::new(0)),
867                current: Arc::new(AtomicUsize::new(0)),
868                peak: Arc::new(AtomicUsize::new(0)),
869            }
870        }
871
872        fn enter(&self) {
873            self.total.fetch_add(1, Ordering::SeqCst);
874            let prev = self.current.fetch_add(1, Ordering::SeqCst);
875            self.peak.fetch_max(prev + 1, Ordering::SeqCst);
876        }
877
878        fn exit(&self) {
879            self.current.fetch_sub(1, Ordering::SeqCst);
880        }
881
882        fn peak(&self) -> usize {
883            self.peak.load(Ordering::SeqCst)
884        }
885
886        fn total(&self) -> usize {
887            self.total.load(Ordering::SeqCst)
888        }
889    }
890
    /// Router state shared with the delayed test handlers.
    #[derive(Clone)]
    struct DelayedRefreshState {
        /// Counters the handlers update on entry and exit.
        counting: CountingState,
        /// How long a handler sleeps before responding.
        delay: Duration,
    }
896
897    async fn delayed_refresh_handler(
898        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
899    ) -> axum::Json<serde_json::Value> {
900        state.counting.enter();
901        tokio::time::sleep(state.delay).await;
902        state.counting.exit();
903        axum::Json(serde_json::json!({
904            "access_token": "refreshed-token",
905            "token_type": "Bearer",
906            "expires_in": 3600,
907            "refresh_token": "new-refresh-token"
908        }))
909    }
910
911    async fn delayed_error_handler(
912        axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
913    ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
914        state.counting.enter();
915        tokio::time::sleep(state.delay).await;
916        state.counting.exit();
917        (
918            axum::http::StatusCode::BAD_REQUEST,
919            axum::Json(serde_json::json!({
920                "error": "invalid_grant",
921                "error_description": "invalid_grant occurred"
922            })),
923        )
924    }
925
926    async fn start_axum_server<H, T>(
927        handler: H,
928        state: DelayedRefreshState,
929    ) -> (url::Url, CountingState)
930    where
931        H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
932        T: 'static,
933    {
934        let counting = state.counting.clone();
935        let app = axum::Router::new()
936            .route("/oauth/token", axum::routing::post(handler))
937            .with_state(state);
938        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
939        let addr = listener.local_addr().unwrap();
940        tokio::spawn(async move {
941            axum::serve(listener, app).await.unwrap();
942        });
943        let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
944        (base_url, counting)
945    }
946
947    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
948        let now = SystemTime::now()
949            .duration_since(UNIX_EPOCH)
950            .unwrap()
951            .as_secs();
952
953        Token {
954            access_token: SecretToken::new(access),
955            token_type: "Bearer".to_string(),
956            expires_at: now + expires_in,
957            refresh_token: if refresh {
958                Some(SecretToken::new("test-refresh-token"))
959            } else {
960                None
961            },
962            region: None,
963            client_id: None,
964            device_instance_id: None,
965        }
966    }
967
968    fn auto_refresh_with_token(
969        dir: &tempfile::TempDir,
970        base_url: &url::Url,
971        token: Token,
972    ) -> AutoRefresh<OAuthRefresher> {
973        let store = ProfileStore::new(dir.path());
974        store.save_profile(&token).unwrap();
975        let refresher = OAuthRefresher::new(
976            Some(store),
977            base_url.clone(),
978            "cli",
979            "ap-southeast-2.aws",
980            None,
981        );
982        AutoRefresh::with_token(refresher, token)
983    }
984
    /// Number of concurrent `get_token` callers each stress test spawns.
    const CONCURRENCY: usize = 50;
986
987    mod given_fresh_token {
988        use super::*;
989
990        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
991        async fn all_callers_return_immediately() {
992            let counting = CountingState::new();
993            let state = DelayedRefreshState {
994                counting: counting.clone(),
995                delay: Duration::from_millis(500),
996            };
997            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
998            let dir = tempfile::tempdir().unwrap();
999            let strategy = Arc::new(auto_refresh_with_token(
1000                &dir,
1001                &base_url,
1002                make_token("fresh-token", 3600, true),
1003            ));
1004
1005            let start = Instant::now();
1006            let mut handles = Vec::with_capacity(CONCURRENCY);
1007            for _ in 0..CONCURRENCY {
1008                let s = Arc::clone(&strategy);
1009                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1010            }
1011
1012            let results: Vec<_> = {
1013                let mut results = Vec::with_capacity(handles.len());
1014                for handle in handles {
1015                    results.push(handle.await.unwrap());
1016                }
1017                results
1018            };
1019            let elapsed = start.elapsed();
1020
1021            for token in &results {
1022                assert_eq!(
1023                    token.as_str(),
1024                    "fresh-token",
1025                    "all callers should receive the fresh token"
1026                );
1027            }
1028
1029            assert!(
1030                elapsed < Duration::from_millis(200),
1031                "expected < 200ms for fresh tokens, got {:?}",
1032                elapsed
1033            );
1034            assert_eq!(stats.total(), 0, "no refresh requests should be made");
1035        }
1036    }
1037
1038    mod given_expiring_but_usable_token {
1039        use super::*;
1040
1041        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1042        async fn non_blocking_reads_during_refresh() {
1043            let counting = CountingState::new();
1044            let state = DelayedRefreshState {
1045                counting: counting.clone(),
1046                delay: Duration::from_millis(500),
1047            };
1048            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1049            let dir = tempfile::tempdir().unwrap();
1050            let strategy = Arc::new(auto_refresh_with_token(
1051                &dir,
1052                &base_url,
1053                make_token("still-usable", 30, true),
1054            ));
1055
1056            let start = Instant::now();
1057            let mut handles = Vec::with_capacity(CONCURRENCY);
1058            for _ in 0..CONCURRENCY {
1059                let s = Arc::clone(&strategy);
1060                handles.push(tokio::spawn(async move {
1061                    let call_start = Instant::now();
1062                    let token = s.get_token().await.unwrap();
1063                    (token, call_start.elapsed())
1064                }));
1065            }
1066
1067            let results: Vec<_> = {
1068                let mut results = Vec::with_capacity(handles.len());
1069                for handle in handles {
1070                    results.push(handle.await.unwrap());
1071                }
1072                results
1073            };
1074            let elapsed = start.elapsed();
1075
1076            for (token, _) in &results {
1077                assert!(
1078                    token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
1079                    "unexpected token: {}",
1080                    token.as_str()
1081                );
1082            }
1083
1084            let fast_callers = results
1085                .iter()
1086                .filter(|(_, dur)| *dur < Duration::from_millis(100))
1087                .count();
1088            assert!(
1089                fast_callers >= CONCURRENCY - 1,
1090                "expected at least {} fast callers, got {} (total elapsed: {:?})",
1091                CONCURRENCY - 1,
1092                fast_callers,
1093                elapsed
1094            );
1095
1096            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1097            assert_eq!(stats.total(), 1, "total refresh requests");
1098        }
1099
1100        /// Reproduces the race condition where a token crosses real expiry during
1101        /// an in-flight non-blocking refresh. Before the fix, late-arriving callers
1102        /// would see `refresh_in_progress = true` + `!is_usable()` and return
1103        /// `Err(Expired)` instead of waiting for the refresh to complete.
1104        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1105        async fn waiters_receive_token_when_expiry_crosses() {
1106            // Token with 1s until real expiry (minimum granularity since
1107            // expires_at is in seconds). is_expired() = true (within 90s leeway),
1108            // is_usable() = true (1s remaining). Refresh takes 1.5s so the token
1109            // crosses real expiry mid-refresh.
1110            let refresh_delay = Duration::from_millis(1500);
1111            let counting = CountingState::new();
1112            let state = DelayedRefreshState {
1113                counting: counting.clone(),
1114                delay: refresh_delay,
1115            };
1116            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1117            let dir = tempfile::tempdir().unwrap();
1118            let strategy = Arc::new(auto_refresh_with_token(
1119                &dir,
1120                &base_url,
1121                make_token("expiring-soon", 1, true),
1122            ));
1123
1124            // First caller triggers the non-blocking refresh and gets the old token.
1125            let first = strategy.get_token().await.unwrap();
1126            assert_eq!(
1127                first.as_str(),
1128                "expiring-soon",
1129                "first caller should receive the expiring token"
1130            );
1131
1132            // Wait for the token to cross real expiry (but refresh is still in-flight).
1133            tokio::time::sleep(Duration::from_millis(1100)).await;
1134
1135            // Launch 50 concurrent callers. Without the fix, these would all get
1136            // Err(Expired) because refresh_in_progress = true and !is_usable().
1137            let mut handles = Vec::with_capacity(CONCURRENCY);
1138            for _ in 0..CONCURRENCY {
1139                let s = Arc::clone(&strategy);
1140                handles.push(tokio::spawn(async move { s.get_token().await }));
1141            }
1142
1143            let results: Vec<_> = {
1144                let mut results = Vec::with_capacity(handles.len());
1145                for handle in handles {
1146                    results.push(handle.await.unwrap());
1147                }
1148                results
1149            };
1150
1151            // All callers must succeed — none should get Expired.
1152            for (i, result) in results.iter().enumerate() {
1153                assert!(
1154                    result.is_ok(),
1155                    "caller {i} got Err({:?}), expected Ok",
1156                    result.as_ref().unwrap_err()
1157                );
1158                assert_eq!(
1159                    result.as_ref().unwrap().as_str(),
1160                    "refreshed-token",
1161                    "caller {i} should receive the refreshed token"
1162                );
1163            }
1164
1165            assert_eq!(stats.total(), 1, "only one refresh request should be made");
1166        }
1167    }
1168
1169    mod given_fully_expired_token {
1170        use super::*;
1171
1172        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1173        async fn all_callers_block_until_refresh() {
1174            let refresh_delay = Duration::from_millis(200);
1175            let counting = CountingState::new();
1176            let state = DelayedRefreshState {
1177                counting: counting.clone(),
1178                delay: refresh_delay,
1179            };
1180            let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1181            let dir = tempfile::tempdir().unwrap();
1182            let strategy = Arc::new(auto_refresh_with_token(
1183                &dir,
1184                &base_url,
1185                make_token("expired-token", 0, true),
1186            ));
1187
1188            let start = Instant::now();
1189            let mut handles = Vec::with_capacity(CONCURRENCY);
1190            for _ in 0..CONCURRENCY {
1191                let s = Arc::clone(&strategy);
1192                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1193            }
1194
1195            let results: Vec<_> = {
1196                let mut results = Vec::with_capacity(handles.len());
1197                for handle in handles {
1198                    results.push(handle.await.unwrap());
1199                }
1200                results
1201            };
1202            let elapsed = start.elapsed();
1203
1204            for token in &results {
1205                assert_eq!(
1206                    token.as_str(),
1207                    "refreshed-token",
1208                    "all callers should receive refreshed token"
1209                );
1210            }
1211
1212            assert!(
1213                elapsed < refresh_delay + Duration::from_millis(200),
1214                "expected < {:?} for blocked callers, got {:?}",
1215                refresh_delay + Duration::from_millis(200),
1216                elapsed
1217            );
1218
1219            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1220            assert_eq!(stats.total(), 1, "total refresh requests");
1221        }
1222
1223        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1224        async fn all_callers_receive_expired_on_failure() {
1225            let counting = CountingState::new();
1226            let state = DelayedRefreshState {
1227                counting: counting.clone(),
1228                delay: Duration::from_millis(10),
1229            };
1230            let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1231            let dir = tempfile::tempdir().unwrap();
1232            let strategy = Arc::new(auto_refresh_with_token(
1233                &dir,
1234                &base_url,
1235                make_token("expired-token", 0, true),
1236            ));
1237
1238            let mut handles = Vec::with_capacity(CONCURRENCY);
1239            for _ in 0..CONCURRENCY {
1240                let s = Arc::clone(&strategy);
1241                handles.push(tokio::spawn(async move { s.get_token().await }));
1242            }
1243
1244            let results: Vec<_> = {
1245                let mut results = Vec::with_capacity(handles.len());
1246                for handle in handles {
1247                    results.push(handle.await.unwrap());
1248                }
1249                results
1250            };
1251
1252            for result in &results {
1253                assert!(result.is_err(), "expected Expired error, got Ok");
1254                let err = result.as_ref().unwrap_err();
1255                assert!(
1256                    matches!(err, AutoRefreshError::Expired),
1257                    "expected Expired, got: {err:?}"
1258                );
1259            }
1260
1261            let state = strategy.state.lock().await;
1262            assert!(
1263                state.token.as_ref().unwrap().refresh_token().is_some(),
1264                "refresh token should be restored after failed refresh"
1265            );
1266            drop(state);
1267
1268            assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1269            assert!(
1270                stats.total() >= 1,
1271                "at least one refresh attempt should be made"
1272            );
1273        }
1274
1275        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1276        async fn retry_succeeds_after_failure() {
1277            // Phase 1: Server returns errors.
1278            let counting1 = CountingState::new();
1279            let state1 = DelayedRefreshState {
1280                counting: counting1.clone(),
1281                delay: Duration::from_millis(50),
1282            };
1283            let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1284            let dir = tempfile::tempdir().unwrap();
1285            let strategy = Arc::new(auto_refresh_with_token(
1286                &dir,
1287                &base_url,
1288                make_token("expired-token", 0, true),
1289            ));
1290
1291            let mut handles = Vec::with_capacity(CONCURRENCY);
1292            for _ in 0..CONCURRENCY {
1293                let s = Arc::clone(&strategy);
1294                handles.push(tokio::spawn(async move { s.get_token().await }));
1295            }
1296
1297            let results: Vec<_> = {
1298                let mut results = Vec::with_capacity(handles.len());
1299                for handle in handles {
1300                    results.push(handle.await.unwrap());
1301                }
1302                results
1303            };
1304
1305            for result in &results {
1306                assert!(
1307                    result.is_err(),
1308                    "first wave: expected Expired, got Ok({})",
1309                    result.as_ref().unwrap().as_str()
1310                );
1311            }
1312
1313            // Phase 2: New server that returns success.
1314            let counting2 = CountingState::new();
1315            let state2 = DelayedRefreshState {
1316                counting: counting2.clone(),
1317                delay: Duration::from_millis(50),
1318            };
1319            let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1320
1321            let strategy2 = Arc::new(auto_refresh_with_token(
1322                &dir,
1323                &base_url2,
1324                make_token("expired-token", 0, true),
1325            ));
1326
1327            let mut handles = Vec::with_capacity(CONCURRENCY);
1328            for _ in 0..CONCURRENCY {
1329                let s = Arc::clone(&strategy2);
1330                handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1331            }
1332
1333            let results: Vec<_> = {
1334                let mut results = Vec::with_capacity(handles.len());
1335                for handle in handles {
1336                    results.push(handle.await.unwrap());
1337                }
1338                results
1339            };
1340
1341            for token in &results {
1342                assert_eq!(
1343                    token.as_str(),
1344                    "refreshed-token",
1345                    "retry callers should receive refreshed token"
1346                );
1347            }
1348
1349            assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1350        }
1351    }
1352
1353    mod given_cancelled_refresh {
1354        use super::*;
1355
1356        /// If a blocking refresh (fully expired token) is cancelled mid-flight,
1357        /// the `CancelGuard` must reset `refresh_in_progress` and notify waiters
1358        /// so the next caller doesn't hang in `wait_for_in_flight_refresh`.
1359        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1360        async fn blocked_callers_recover_after_cancellation() {
1361            let counting = CountingState::new();
1362            let state = DelayedRefreshState {
1363                counting: counting.clone(),
1364                delay: Duration::from_secs(10), // Very slow — will be cancelled
1365            };
1366            let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1367            let dir = tempfile::tempdir().unwrap();
1368            let strategy = Arc::new(auto_refresh_with_token(
1369                &dir,
1370                &base_url,
1371                make_token("expired-token", 0, true),
1372            ));
1373
1374            // Spawn get_token and let the blocking refresh start.
1375            let s = Arc::clone(&strategy);
1376            let handle = tokio::spawn(async move { s.get_token().await });
1377            tokio::time::sleep(Duration::from_millis(100)).await;
1378
1379            // Cancel the refresh mid-flight.
1380            handle.abort();
1381            let _ = handle.await;
1382
1383            // The next caller must not hang. The credential is lost (refresh
1384            // token was taken before the HTTP call), so the result is Expired,
1385            // but the important thing is that it completes promptly.
1386            let s = Arc::clone(&strategy);
1387            let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1388
1389            assert!(
1390                result.is_ok(),
1391                "get_token() should not hang after cancelled blocking refresh"
1392            );
1393        }
1394
1395        /// If a non-blocking refresh (expiring-but-usable token) is cancelled
1396        /// mid-flight, the `CancelGuard` must reset `refresh_in_progress` and
1397        /// notify waiters so they don't hang once the token crosses real expiry.
1398        #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1399        async fn non_blocking_callers_recover_after_cancellation() {
1400            let counting = CountingState::new();
1401            let state = DelayedRefreshState {
1402                counting: counting.clone(),
1403                delay: Duration::from_secs(10), // Very slow — will be cancelled
1404            };
1405            let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1406            let dir = tempfile::tempdir().unwrap();
1407            // Token expires in 30s — is_expired() = true, is_usable() = true.
1408            let strategy = Arc::new(auto_refresh_with_token(
1409                &dir,
1410                &base_url,
1411                make_token("still-usable", 30, true),
1412            ));
1413
1414            // Spawn get_token — triggers non-blocking refresh, drops lock, then
1415            // blocks on the slow HTTP call.
1416            let s = Arc::clone(&strategy);
1417            let handle = tokio::spawn(async move { s.get_token().await });
1418            tokio::time::sleep(Duration::from_millis(100)).await;
1419
1420            // Cancel the refresh mid-flight.
1421            handle.abort();
1422            let _ = handle.await;
1423
1424            // The next caller must not hang. The token is still usable so it
1425            // should be returned even though the refresh was cancelled.
1426            let s = Arc::clone(&strategy);
1427            let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1428
1429            assert!(
1430                result.is_ok(),
1431                "get_token() should not hang after cancelled non-blocking refresh"
1432            );
1433            let result = result.unwrap();
1434            assert!(
1435                result.is_ok(),
1436                "expected Ok with still-usable token, got: {:?}",
1437                result.unwrap_err()
1438            );
1439        }
1440    }
1441}