1use std::sync::atomic::{AtomicBool, Ordering};
2
3use tokio::sync::{Mutex, MutexGuard, Notify};
4
5use crate::refresher::Refresher;
6use crate::{ServiceToken, Token};
7
/// Errors returned by [`AutoRefresh`] when it cannot hand out a token.
///
/// Convertible into the crate-wide [`crate::AuthError`] through a `From` impl.
#[derive(Debug, thiserror::Error)]
pub(crate) enum AutoRefreshError {
    /// No token is cached, and the refresher offered no credential to obtain one.
    #[error("No token found")]
    NotFound,
    /// The cached token is not usable and no refresh replaced it.
    #[error("Token has expired")]
    Expired,
    /// Wraps a [`crate::AuthError`]; produced here when initial authentication fails
    /// (`#[from]` also lets `?` convert any `AuthError`).
    #[error("Auth error: {0}")]
    Auth(#[from] crate::AuthError),
}
24
25impl From<AutoRefreshError> for crate::AuthError {
26 fn from(err: AutoRefreshError) -> Self {
27 match err {
28 AutoRefreshError::NotFound => crate::AuthError::NotAuthenticated,
29 AutoRefreshError::Expired => crate::AuthError::TokenExpired,
30 AutoRefreshError::Auth(e) => e,
31 }
32 }
33}
34
/// Caches a [`Token`] and refreshes it through a [`Refresher`] once the token reports
/// `is_expired()`.
///
/// Concurrent callers share a single in-flight refresh: `refresh_in_progress` marks it and
/// `refresh_notify` wakes callers that are waiting for it.
pub(crate) struct AutoRefresh<R> {
    /// Source of credentials and of freshly issued tokens.
    refresher: R,
    /// The cached token. This is a tokio mutex because the guard is held across `.await`
    /// while an unusable token is being refreshed.
    state: Mutex<State>,
    /// `true` while a refresh request is in flight.
    refresh_in_progress: AtomicBool,
    /// Woken with `notify_waiters` when a lock-free refresh finishes, or when a refresh is
    /// abandoned through `CancelGuard`.
    refresh_notify: Notify,
}

/// Mutable state protected by `AutoRefresh::state`.
struct State {
    /// Cached token. `None` until the first successful authentication when the cache is
    /// built with `AutoRefresh::new`.
    token: Option<Token>,
}
55
56struct CancelGuard<'a> {
62 in_progress: &'a AtomicBool,
63 notify: &'a Notify,
64 defused: bool,
65}
66
67impl Drop for CancelGuard<'_> {
68 fn drop(&mut self) {
69 if !self.defused {
70 self.in_progress.store(false, Ordering::Release);
71 self.notify.notify_waiters();
72 }
73 }
74}
75
76impl CancelGuard<'_> {
77 fn defuse(&mut self) {
78 self.defused = true;
79 }
80}
81
82impl State {
83 fn service_token(&self) -> Result<ServiceToken, AutoRefreshError> {
84 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
85 Ok(ServiceToken::new(token.access_token().clone()))
86 }
87
88 fn require_usable_token(&self) -> Result<ServiceToken, AutoRefreshError> {
89 let token = self.token.as_ref().ok_or(AutoRefreshError::NotFound)?;
90 if token.is_usable() {
91 Ok(ServiceToken::new(token.access_token().clone()))
92 } else {
93 Err(AutoRefreshError::Expired)
94 }
95 }
96}
97
98impl<R> AutoRefresh<R> {
99 pub(crate) fn new(refresher: R) -> Self {
105 Self {
106 refresher,
107 state: Mutex::new(State { token: None }),
108 refresh_in_progress: AtomicBool::new(false),
109 refresh_notify: Notify::new(),
110 }
111 }
112
113 pub(crate) fn with_token(refresher: R, token: Token) -> Self {
118 Self {
119 refresher,
120 state: Mutex::new(State { token: Some(token) }),
121 refresh_in_progress: AtomicBool::new(false),
122 refresh_notify: Notify::new(),
123 }
124 }
125}
126
127impl<R: Refresher> AutoRefresh<R> {
128 pub(crate) async fn get_token(&self) -> Result<ServiceToken, AutoRefreshError> {
130 let mut state = self.state.lock().await;
131
132 if state.token.is_none() {
133 return self.initial_auth(&mut state).await;
134 }
135
136 if !state.token.as_ref().is_some_and(|t| t.is_expired()) {
137 return state.service_token();
138 }
139
140 if self.refresh_in_progress.load(Ordering::Acquire) {
141 return self.wait_for_in_flight_refresh(state).await;
142 }
143
144 let Some(credential) = self.refresher.try_credential(state.token.as_mut()) else {
145 return state.require_usable_token();
146 };
147
148 self.refresh_in_progress.store(true, Ordering::Release);
149
150 if state.token.as_ref().is_some_and(|t| t.is_usable()) {
151 self.refresh_non_blocking(state, credential).await
152 } else {
153 self.refresh_blocking(&mut state, credential).await
154 }
155 }
156
157 async fn initial_auth(&self, state: &mut State) -> Result<ServiceToken, AutoRefreshError> {
161 let Some(credential) = self.refresher.try_credential(None) else {
162 return Err(AutoRefreshError::NotFound);
163 };
164 self.refresh_in_progress.store(true, Ordering::Release);
165 let mut guard = CancelGuard {
166 in_progress: &self.refresh_in_progress,
167 notify: &self.refresh_notify,
168 defused: false,
169 };
170 match self.refresher.refresh(&credential).await {
171 Ok(new_token) => {
172 guard.defuse();
173 self.refresher.save(&new_token);
174 let service_token = ServiceToken::new(new_token.access_token().clone());
175 state.token = Some(new_token);
176 self.refresh_in_progress.store(false, Ordering::Release);
177 Ok(service_token)
178 }
179 Err(err) => {
180 guard.defuse();
181 self.refresh_in_progress.store(false, Ordering::Release);
182 Err(AutoRefreshError::Auth(err))
183 }
184 }
185 }
186
187 async fn wait_for_in_flight_refresh(
193 &self,
194 state: MutexGuard<'_, State>,
195 ) -> Result<ServiceToken, AutoRefreshError> {
196 if let Ok(token) = state.service_token() {
197 if state.token.as_ref().is_some_and(|t| t.is_usable()) {
198 return Ok(token);
199 }
200 }
201 let notified = self.refresh_notify.notified();
204 drop(state);
205 notified.await;
206 let state = self.state.lock().await;
208 state.require_usable_token()
209 }
210
211 async fn refresh_non_blocking(
221 &self,
222 state: MutexGuard<'_, State>,
223 credential: R::Credential,
224 ) -> Result<ServiceToken, AutoRefreshError> {
225 let current_service_token = state.service_token()?;
226 drop(state);
227
228 let mut guard = CancelGuard {
229 in_progress: &self.refresh_in_progress,
230 notify: &self.refresh_notify,
231 defused: false,
232 };
233
234 match self.refresher.refresh(&credential).await {
235 Ok(new_token) => {
236 guard.defuse();
237 self.refresher.save(&new_token);
238 let mut state = self.state.lock().await;
239 state.token = Some(new_token);
240 self.refresh_in_progress.store(false, Ordering::Release);
241 }
242 Err(err) => {
243 guard.defuse();
244 tracing::warn!(%err, "token refresh failed (token still usable)");
245 let mut state = self.state.lock().await;
246 if let Some(token) = state.token.as_mut() {
247 self.refresher.restore(token, credential);
248 }
249 self.refresh_in_progress.store(false, Ordering::Release);
250 }
251 }
252
253 self.refresh_notify.notify_waiters();
254 Ok(current_service_token)
255 }
256
257 async fn refresh_blocking(
266 &self,
267 state: &mut State,
268 credential: R::Credential,
269 ) -> Result<ServiceToken, AutoRefreshError> {
270 let mut guard = CancelGuard {
271 in_progress: &self.refresh_in_progress,
272 notify: &self.refresh_notify,
273 defused: false,
274 };
275 match self.refresher.refresh(&credential).await {
276 Ok(new_token) => {
277 guard.defuse();
278 self.refresher.save(&new_token);
279 let service_token = ServiceToken::new(new_token.access_token().clone());
280 state.token = Some(new_token);
281 self.refresh_in_progress.store(false, Ordering::Release);
282 Ok(service_token)
283 }
284 Err(err) => {
285 guard.defuse();
286 tracing::warn!(%err, "token refresh failed");
287 if let Some(token) = state.token.as_mut() {
288 self.refresher.restore(token, credential);
289 }
290 self.refresh_in_progress.store(false, Ordering::Release);
291 Err(AutoRefreshError::Expired)
292 }
293 }
294 }
295}
296
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    //! Behavioural tests for `AutoRefresh` against a mocked OAuth token endpoint.

    use super::*;
    use crate::oauth_refresher::OAuthRefresher;
    use crate::SecretToken;
    use mocktail::prelude::*;
    use stack_profile::ProfileStore;
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a token that expires `expires_in` seconds from now. When `refresh` is set, it
    /// carries the refresh token `test-refresh-token`.
    fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
        let refresh_token = refresh.then(|| SecretToken::new("test-refresh-token"));

        Token {
            access_token: SecretToken::new(access),
            token_type: "Bearer".to_string(),
            expires_at: now + expires_in,
            refresh_token,
            region: None,
            client_id: None,
            device_instance_id: None,
        }
    }

    /// Successful token-endpoint body that issues `access` plus a new refresh token.
    fn refresh_response_json(access: &str) -> serde_json::Value {
        serde_json::json!({
            "access_token": access,
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": "new-refresh-token"
        })
    }

    /// OAuth-style error body for the given error code.
    fn error_json(error: &str) -> serde_json::Value {
        let description = format!("{error} occurred");
        serde_json::json!({ "error": error, "error_description": description })
    }

    /// Starts a mock HTTP server serving `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("auto-refresh-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    /// OAuth refresher pointed at the mock server and backed by `store`.
    fn oauth_refresher(store: ProfileStore, server: &MockServer) -> OAuthRefresher {
        OAuthRefresher::new(Some(store), server.url(""), "cli", "ap-southeast-2.aws", None)
    }

    /// Persists `token` into `dir` and returns an `AutoRefresh` seeded with it.
    fn auto_refresh_with_token(
        dir: &tempfile::TempDir,
        server: &MockServer,
        token: Token,
    ) -> AutoRefresh<OAuthRefresher> {
        let store = ProfileStore::new(dir.path());
        store.save_profile(&token).unwrap();
        AutoRefresh::with_token(oauth_refresher(store, server), token)
    }

    mod given_no_cached_token {
        use super::*;

        #[tokio::test]
        async fn returns_not_found_for_oauth() {
            let server = start_server(MockSet::new()).await;
            let refresher = oauth_refresher(ProfileStore::new("/tmp/nonexistent"), &server);
            let auto_refresh = AutoRefresh::new(refresher);

            let err = auto_refresh.get_token().await.unwrap_err();

            assert!(
                matches!(err, AutoRefreshError::NotFound),
                "expected NotFound, got: {err:?}"
            );
        }
    }

    mod given_fresh_token {
        use super::*;

        #[tokio::test]
        async fn returns_cached_token() {
            let tmp = tempfile::tempdir().unwrap();
            let server = start_server(MockSet::new()).await;
            let auto_refresh =
                auto_refresh_with_token(&tmp, &server, make_token("my-access-token", 3600, false));

            let token = auto_refresh.get_token().await.unwrap();

            assert_eq!(
                token.as_str(),
                "my-access-token",
                "should return the cached access token"
            );
        }

        #[tokio::test]
        async fn caches_across_calls() {
            let tmp = tempfile::tempdir().unwrap();
            let server = start_server(MockSet::new()).await;
            let auto_refresh =
                auto_refresh_with_token(&tmp, &server, make_token("my-access-token", 3600, false));

            let first = auto_refresh.get_token().await.unwrap();
            assert_eq!(
                first.as_str(),
                "my-access-token",
                "first call should return the cached token"
            );

            // Removing the on-disk profile must not affect the in-memory cache.
            std::fs::remove_file(tmp.path().join("auth.json")).unwrap();

            let second = auto_refresh.get_token().await.unwrap();
            assert_eq!(
                second.as_str(),
                "my-access-token",
                "second call should return the cached token even after file deletion"
            );
        }

        #[tokio::test]
        async fn does_not_trigger_refresh() {
            // Any hit on the token endpoint would answer with an error.
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.internal_server_error()
                    .json(error_json("should_not_be_called"));
            });
            let server = start_server(mocks).await;
            let tmp = tempfile::tempdir().unwrap();
            let auto_refresh =
                auto_refresh_with_token(&tmp, &server, make_token("fresh-token", 3600, true));

            let token = auto_refresh.get_token().await.unwrap();

            assert_eq!(
                token.as_str(),
                "fresh-token",
                "should return fresh token without triggering refresh"
            );
        }
    }

    mod given_fully_expired_token {
        use super::*;

        mod without_refresh_token {
            use super::*;

            #[tokio::test]
            async fn returns_expired() {
                let tmp = tempfile::tempdir().unwrap();
                let server = start_server(MockSet::new()).await;
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("old-token", 0, false));

                let err = auto_refresh.get_token().await.unwrap_err();

                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired, got: {err:?}"
                );
            }
        }

        mod with_refresh_token {
            use super::*;

            #[tokio::test]
            async fn refreshes_and_returns_new_token() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("old-token", 0, true));

                let token = auto_refresh.get_token().await.unwrap();

                assert_eq!(
                    token.as_str(),
                    "refreshed-token",
                    "should return the refreshed token"
                );
            }

            #[tokio::test]
            async fn persists_refreshed_token_to_disk() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("old-token", 0, true));

                let _ = auto_refresh.get_token().await.unwrap();

                let on_disk: Token = ProfileStore::new(tmp.path()).load_profile().unwrap();
                assert_eq!(
                    on_disk.access_token().as_str(),
                    "refreshed-token",
                    "refreshed token should be persisted to disk"
                );
            }

            #[tokio::test]
            async fn returns_expired_on_refresh_failure() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("invalid_grant"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("old-token", 0, true));

                let err = auto_refresh.get_token().await.unwrap_err();

                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired after failed refresh, got: {err:?}"
                );
            }

            #[tokio::test]
            async fn restores_refresh_token_after_failure() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("invalid_grant"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("old-token", 0, true));

                let err = auto_refresh.get_token().await.unwrap_err();
                assert!(
                    matches!(err, AutoRefreshError::Expired),
                    "expected Expired on first attempt, got: {err:?}"
                );

                {
                    let state = auto_refresh.state.lock().await;
                    assert!(
                        state.token.is_some(),
                        "token should still be cached after failed refresh"
                    );
                    assert!(
                        state.token.as_ref().unwrap().refresh_token().is_some(),
                        "refresh token should be restored for retry"
                    );
                }

                // Swap the endpoint to a succeeding one and retry.
                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });

                let retried = auto_refresh.get_token().await.unwrap();
                assert_eq!(
                    retried.as_str(),
                    "refreshed-token",
                    "retry should succeed with restored refresh token"
                );
            }

            #[tokio::test]
            async fn sequential_calls_only_refresh_once() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-once"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("old-token", 0, true));

                let first = auto_refresh.get_token().await.unwrap();
                assert_eq!(
                    first.as_str(),
                    "refreshed-once",
                    "first call should trigger refresh"
                );

                // A second refresh would now issue a different token.
                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-twice"));
                });

                for _attempt in 0..4 {
                    let token = auto_refresh.get_token().await.unwrap();
                    assert_eq!(
                        token.as_str(),
                        "refreshed-once",
                        "should return cached refreshed token, not trigger another refresh"
                    );
                }
            }

            #[tokio::test]
            async fn prevents_second_refresh_after_success() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("old-token", 0, true));

                let first = auto_refresh.get_token().await.unwrap();
                assert_eq!(
                    first.as_str(),
                    "refreshed-token",
                    "first call should refresh the token"
                );

                // Any further refresh would now fail.
                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("should_not_be_called"));
                });

                let second = auto_refresh.get_token().await.unwrap();
                assert_eq!(
                    second.as_str(),
                    "refreshed-token",
                    "second call should return cached refreshed token"
                );
            }
        }
    }

    mod given_expiring_but_usable_token {
        use super::*;

        mod when_refresh_fails {
            use super::*;

            #[tokio::test]
            async fn returns_current_token() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("server_error"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("still-usable", 30, true));

                let token = auto_refresh.get_token().await.unwrap();
                assert_eq!(
                    token.as_str(),
                    "still-usable",
                    "should return still-usable token despite failed refresh"
                );

                let state = auto_refresh.state.lock().await;
                assert!(state.token.is_some(), "token should still be cached");
                let cached = state.token.as_ref().unwrap();
                assert_eq!(
                    cached.access_token().as_str(),
                    "still-usable",
                    "access token should be unchanged after failed refresh"
                );
                assert!(
                    cached.refresh_token().is_some(),
                    "refresh token should be restored after failed refresh"
                );
            }

            #[tokio::test]
            async fn restores_refresh_token_for_retry() {
                let mut mocks = MockSet::new();
                mocks.mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.bad_request().json(error_json("server_error"));
                });
                let server = start_server(mocks).await;
                let tmp = tempfile::tempdir().unwrap();
                let auto_refresh =
                    auto_refresh_with_token(&tmp, &server, make_token("still-usable", 30, true));

                let first = auto_refresh.get_token().await.unwrap();
                assert_eq!(
                    first.as_str(),
                    "still-usable",
                    "first call should return still-usable token"
                );

                server.mocks().clear();
                server.mocks().mock(|when, then| {
                    when.post().path("/oauth/token");
                    then.json(refresh_response_json("refreshed-token"));
                });

                let second = auto_refresh.get_token().await.unwrap();
                assert!(
                    matches!(second.as_str(), "still-usable" | "refreshed-token"),
                    "expected old or refreshed token, got: {}",
                    second.as_str()
                );

                let state = auto_refresh.state.lock().await;
                assert_eq!(
                    state.token.as_ref().unwrap().access_token().as_str(),
                    "refreshed-token",
                    "cache should hold the refreshed token after retry"
                );
            }
        }
    }

    mod given_concurrent_callers {
        use super::*;

        /// Spawns a task that fetches a token from the shared `AutoRefresh`.
        fn spawn_caller(
            auto_refresh: &Arc<AutoRefresh<OAuthRefresher>>,
        ) -> tokio::task::JoinHandle<ServiceToken> {
            let auto_refresh = Arc::clone(auto_refresh);
            tokio::spawn(async move { auto_refresh.get_token().await.unwrap() })
        }

        #[tokio::test]
        async fn returns_usable_token_while_refreshing() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });
            let server = start_server(mocks).await;
            let tmp = tempfile::tempdir().unwrap();
            let auto_refresh = Arc::new(auto_refresh_with_token(
                &tmp,
                &server,
                make_token("still-usable", 30, true),
            ));

            let handle_a = spawn_caller(&auto_refresh);
            let handle_b = spawn_caller(&auto_refresh);

            let (joined_a, joined_b) = tokio::join!(handle_a, handle_b);
            let token_a = joined_a.unwrap();
            let token_b = joined_b.unwrap();

            assert!(
                matches!(token_a.as_str(), "still-usable" | "refreshed-token"),
                "unexpected token_a: {}",
                token_a.as_str()
            );
            assert!(
                matches!(token_b.as_str(), "still-usable" | "refreshed-token"),
                "unexpected token_b: {}",
                token_b.as_str()
            );
        }

        #[tokio::test]
        async fn blocks_until_refresh_completes() {
            let mut mocks = MockSet::new();
            mocks.mock(|when, then| {
                when.post().path("/oauth/token");
                then.json(refresh_response_json("refreshed-token"));
            });
            let server = start_server(mocks).await;
            let tmp = tempfile::tempdir().unwrap();
            let auto_refresh = Arc::new(auto_refresh_with_token(
                &tmp,
                &server,
                make_token("expired-token", 0, true),
            ));

            let handle_a = spawn_caller(&auto_refresh);
            let handle_b = spawn_caller(&auto_refresh);

            let (joined_a, joined_b) = tokio::join!(handle_a, handle_b);
            let token_a = joined_a.unwrap();
            let token_b = joined_b.unwrap();

            assert_eq!(
                token_a.as_str(),
                "refreshed-token",
                "caller a should receive refreshed token"
            );
            assert_eq!(
                token_b.as_str(),
                "refreshed-token",
                "caller b should receive refreshed token"
            );
        }
    }
}
843
844#[cfg(test)]
845#[allow(clippy::unwrap_used)]
846mod stress_tests {
847 use super::*;
848 use crate::oauth_refresher::OAuthRefresher;
849 use crate::SecretToken;
850 use stack_profile::ProfileStore;
851 use std::sync::atomic::{AtomicUsize, Ordering};
852 use std::sync::Arc;
853 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
854
/// Concurrency probe for the fake token endpoint.
///
/// Tracks the total number of requests, the number currently in flight, and the highest
/// in-flight count observed. Clones share the same counters.
#[derive(Clone)]
struct CountingState {
    total: Arc<AtomicUsize>,
    current: Arc<AtomicUsize>,
    peak: Arc<AtomicUsize>,
}

impl CountingState {
    fn new() -> Self {
        let zero = || Arc::new(AtomicUsize::new(0));
        Self {
            total: zero(),
            current: zero(),
            peak: zero(),
        }
    }

    /// Records the start of a request and raises the peak if needed.
    fn enter(&self) {
        self.total.fetch_add(1, Ordering::SeqCst);
        let in_flight = self.current.fetch_add(1, Ordering::SeqCst) + 1;
        self.peak.fetch_max(in_flight, Ordering::SeqCst);
    }

    /// Records the end of a request.
    fn exit(&self) {
        self.current.fetch_sub(1, Ordering::SeqCst);
    }

    /// Highest number of simultaneously in-flight requests seen so far.
    fn peak(&self) -> usize {
        self.peak.load(Ordering::SeqCst)
    }

    /// Total number of requests started so far.
    fn total(&self) -> usize {
        self.total.load(Ordering::SeqCst)
    }
}
890
891 #[derive(Clone)]
892 struct DelayedRefreshState {
893 counting: CountingState,
894 delay: Duration,
895 }
896
897 async fn delayed_refresh_handler(
898 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
899 ) -> axum::Json<serde_json::Value> {
900 state.counting.enter();
901 tokio::time::sleep(state.delay).await;
902 state.counting.exit();
903 axum::Json(serde_json::json!({
904 "access_token": "refreshed-token",
905 "token_type": "Bearer",
906 "expires_in": 3600,
907 "refresh_token": "new-refresh-token"
908 }))
909 }
910
911 async fn delayed_error_handler(
912 axum::extract::State(state): axum::extract::State<DelayedRefreshState>,
913 ) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
914 state.counting.enter();
915 tokio::time::sleep(state.delay).await;
916 state.counting.exit();
917 (
918 axum::http::StatusCode::BAD_REQUEST,
919 axum::Json(serde_json::json!({
920 "error": "invalid_grant",
921 "error_description": "invalid_grant occurred"
922 })),
923 )
924 }
925
926 async fn start_axum_server<H, T>(
927 handler: H,
928 state: DelayedRefreshState,
929 ) -> (url::Url, CountingState)
930 where
931 H: axum::handler::Handler<T, DelayedRefreshState> + Clone + Send + 'static,
932 T: 'static,
933 {
934 let counting = state.counting.clone();
935 let app = axum::Router::new()
936 .route("/oauth/token", axum::routing::post(handler))
937 .with_state(state);
938 let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
939 let addr = listener.local_addr().unwrap();
940 tokio::spawn(async move {
941 axum::serve(listener, app).await.unwrap();
942 });
943 let base_url = url::Url::parse(&format!("http://{addr}")).unwrap();
944 (base_url, counting)
945 }
946
947 fn make_token(access: &str, expires_in: u64, refresh: bool) -> Token {
948 let now = SystemTime::now()
949 .duration_since(UNIX_EPOCH)
950 .unwrap()
951 .as_secs();
952
953 Token {
954 access_token: SecretToken::new(access),
955 token_type: "Bearer".to_string(),
956 expires_at: now + expires_in,
957 refresh_token: if refresh {
958 Some(SecretToken::new("test-refresh-token"))
959 } else {
960 None
961 },
962 region: None,
963 client_id: None,
964 device_instance_id: None,
965 }
966 }
967
968 fn auto_refresh_with_token(
969 dir: &tempfile::TempDir,
970 base_url: &url::Url,
971 token: Token,
972 ) -> AutoRefresh<OAuthRefresher> {
973 let store = ProfileStore::new(dir.path());
974 store.save_profile(&token).unwrap();
975 let refresher = OAuthRefresher::new(
976 Some(store),
977 base_url.clone(),
978 "cli",
979 "ap-southeast-2.aws",
980 None,
981 );
982 AutoRefresh::with_token(refresher, token)
983 }
984
/// Number of concurrent `get_token` callers spawned by each stress test.
const CONCURRENCY: usize = 50;
986
987 mod given_fresh_token {
988 use super::*;
989
990 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
991 async fn all_callers_return_immediately() {
992 let counting = CountingState::new();
993 let state = DelayedRefreshState {
994 counting: counting.clone(),
995 delay: Duration::from_millis(500),
996 };
997 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
998 let dir = tempfile::tempdir().unwrap();
999 let strategy = Arc::new(auto_refresh_with_token(
1000 &dir,
1001 &base_url,
1002 make_token("fresh-token", 3600, true),
1003 ));
1004
1005 let start = Instant::now();
1006 let mut handles = Vec::with_capacity(CONCURRENCY);
1007 for _ in 0..CONCURRENCY {
1008 let s = Arc::clone(&strategy);
1009 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1010 }
1011
1012 let results: Vec<_> = {
1013 let mut results = Vec::with_capacity(handles.len());
1014 for handle in handles {
1015 results.push(handle.await.unwrap());
1016 }
1017 results
1018 };
1019 let elapsed = start.elapsed();
1020
1021 for token in &results {
1022 assert_eq!(
1023 token.as_str(),
1024 "fresh-token",
1025 "all callers should receive the fresh token"
1026 );
1027 }
1028
1029 assert!(
1030 elapsed < Duration::from_millis(200),
1031 "expected < 200ms for fresh tokens, got {:?}",
1032 elapsed
1033 );
1034 assert_eq!(stats.total(), 0, "no refresh requests should be made");
1035 }
1036 }
1037
1038 mod given_expiring_but_usable_token {
1039 use super::*;
1040
1041 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1042 async fn non_blocking_reads_during_refresh() {
1043 let counting = CountingState::new();
1044 let state = DelayedRefreshState {
1045 counting: counting.clone(),
1046 delay: Duration::from_millis(500),
1047 };
1048 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1049 let dir = tempfile::tempdir().unwrap();
1050 let strategy = Arc::new(auto_refresh_with_token(
1051 &dir,
1052 &base_url,
1053 make_token("still-usable", 30, true),
1054 ));
1055
1056 let start = Instant::now();
1057 let mut handles = Vec::with_capacity(CONCURRENCY);
1058 for _ in 0..CONCURRENCY {
1059 let s = Arc::clone(&strategy);
1060 handles.push(tokio::spawn(async move {
1061 let call_start = Instant::now();
1062 let token = s.get_token().await.unwrap();
1063 (token, call_start.elapsed())
1064 }));
1065 }
1066
1067 let results: Vec<_> = {
1068 let mut results = Vec::with_capacity(handles.len());
1069 for handle in handles {
1070 results.push(handle.await.unwrap());
1071 }
1072 results
1073 };
1074 let elapsed = start.elapsed();
1075
1076 for (token, _) in &results {
1077 assert!(
1078 token.as_str() == "still-usable" || token.as_str() == "refreshed-token",
1079 "unexpected token: {}",
1080 token.as_str()
1081 );
1082 }
1083
1084 let fast_callers = results
1085 .iter()
1086 .filter(|(_, dur)| *dur < Duration::from_millis(100))
1087 .count();
1088 assert!(
1089 fast_callers >= CONCURRENCY - 1,
1090 "expected at least {} fast callers, got {} (total elapsed: {:?})",
1091 CONCURRENCY - 1,
1092 fast_callers,
1093 elapsed
1094 );
1095
1096 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1097 assert_eq!(stats.total(), 1, "total refresh requests");
1098 }
1099
1100 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1105 async fn waiters_receive_token_when_expiry_crosses() {
1106 let refresh_delay = Duration::from_millis(1500);
1111 let counting = CountingState::new();
1112 let state = DelayedRefreshState {
1113 counting: counting.clone(),
1114 delay: refresh_delay,
1115 };
1116 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1117 let dir = tempfile::tempdir().unwrap();
1118 let strategy = Arc::new(auto_refresh_with_token(
1119 &dir,
1120 &base_url,
1121 make_token("expiring-soon", 1, true),
1122 ));
1123
1124 let first = strategy.get_token().await.unwrap();
1126 assert_eq!(
1127 first.as_str(),
1128 "expiring-soon",
1129 "first caller should receive the expiring token"
1130 );
1131
1132 tokio::time::sleep(Duration::from_millis(1100)).await;
1134
1135 let mut handles = Vec::with_capacity(CONCURRENCY);
1138 for _ in 0..CONCURRENCY {
1139 let s = Arc::clone(&strategy);
1140 handles.push(tokio::spawn(async move { s.get_token().await }));
1141 }
1142
1143 let results: Vec<_> = {
1144 let mut results = Vec::with_capacity(handles.len());
1145 for handle in handles {
1146 results.push(handle.await.unwrap());
1147 }
1148 results
1149 };
1150
1151 for (i, result) in results.iter().enumerate() {
1153 assert!(
1154 result.is_ok(),
1155 "caller {i} got Err({:?}), expected Ok",
1156 result.as_ref().unwrap_err()
1157 );
1158 assert_eq!(
1159 result.as_ref().unwrap().as_str(),
1160 "refreshed-token",
1161 "caller {i} should receive the refreshed token"
1162 );
1163 }
1164
1165 assert_eq!(stats.total(), 1, "only one refresh request should be made");
1166 }
1167 }
1168
1169 mod given_fully_expired_token {
1170 use super::*;
1171
1172 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1173 async fn all_callers_block_until_refresh() {
1174 let refresh_delay = Duration::from_millis(200);
1175 let counting = CountingState::new();
1176 let state = DelayedRefreshState {
1177 counting: counting.clone(),
1178 delay: refresh_delay,
1179 };
1180 let (base_url, stats) = start_axum_server(delayed_refresh_handler, state).await;
1181 let dir = tempfile::tempdir().unwrap();
1182 let strategy = Arc::new(auto_refresh_with_token(
1183 &dir,
1184 &base_url,
1185 make_token("expired-token", 0, true),
1186 ));
1187
1188 let start = Instant::now();
1189 let mut handles = Vec::with_capacity(CONCURRENCY);
1190 for _ in 0..CONCURRENCY {
1191 let s = Arc::clone(&strategy);
1192 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1193 }
1194
1195 let results: Vec<_> = {
1196 let mut results = Vec::with_capacity(handles.len());
1197 for handle in handles {
1198 results.push(handle.await.unwrap());
1199 }
1200 results
1201 };
1202 let elapsed = start.elapsed();
1203
1204 for token in &results {
1205 assert_eq!(
1206 token.as_str(),
1207 "refreshed-token",
1208 "all callers should receive refreshed token"
1209 );
1210 }
1211
1212 assert!(
1213 elapsed < refresh_delay + Duration::from_millis(200),
1214 "expected < {:?} for blocked callers, got {:?}",
1215 refresh_delay + Duration::from_millis(200),
1216 elapsed
1217 );
1218
1219 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1220 assert_eq!(stats.total(), 1, "total refresh requests");
1221 }
1222
1223 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1224 async fn all_callers_receive_expired_on_failure() {
1225 let counting = CountingState::new();
1226 let state = DelayedRefreshState {
1227 counting: counting.clone(),
1228 delay: Duration::from_millis(10),
1229 };
1230 let (base_url, stats) = start_axum_server(delayed_error_handler, state).await;
1231 let dir = tempfile::tempdir().unwrap();
1232 let strategy = Arc::new(auto_refresh_with_token(
1233 &dir,
1234 &base_url,
1235 make_token("expired-token", 0, true),
1236 ));
1237
1238 let mut handles = Vec::with_capacity(CONCURRENCY);
1239 for _ in 0..CONCURRENCY {
1240 let s = Arc::clone(&strategy);
1241 handles.push(tokio::spawn(async move { s.get_token().await }));
1242 }
1243
1244 let results: Vec<_> = {
1245 let mut results = Vec::with_capacity(handles.len());
1246 for handle in handles {
1247 results.push(handle.await.unwrap());
1248 }
1249 results
1250 };
1251
1252 for result in &results {
1253 assert!(result.is_err(), "expected Expired error, got Ok");
1254 let err = result.as_ref().unwrap_err();
1255 assert!(
1256 matches!(err, AutoRefreshError::Expired),
1257 "expected Expired, got: {err:?}"
1258 );
1259 }
1260
1261 let state = strategy.state.lock().await;
1262 assert!(
1263 state.token.as_ref().unwrap().refresh_token().is_some(),
1264 "refresh token should be restored after failed refresh"
1265 );
1266 drop(state);
1267
1268 assert_eq!(stats.peak(), 1, "peak concurrency to refresh endpoint");
1269 assert!(
1270 stats.total() >= 1,
1271 "at least one refresh attempt should be made"
1272 );
1273 }
1274
1275 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1276 async fn retry_succeeds_after_failure() {
1277 let counting1 = CountingState::new();
1279 let state1 = DelayedRefreshState {
1280 counting: counting1.clone(),
1281 delay: Duration::from_millis(50),
1282 };
1283 let (base_url, _) = start_axum_server(delayed_error_handler, state1).await;
1284 let dir = tempfile::tempdir().unwrap();
1285 let strategy = Arc::new(auto_refresh_with_token(
1286 &dir,
1287 &base_url,
1288 make_token("expired-token", 0, true),
1289 ));
1290
1291 let mut handles = Vec::with_capacity(CONCURRENCY);
1292 for _ in 0..CONCURRENCY {
1293 let s = Arc::clone(&strategy);
1294 handles.push(tokio::spawn(async move { s.get_token().await }));
1295 }
1296
1297 let results: Vec<_> = {
1298 let mut results = Vec::with_capacity(handles.len());
1299 for handle in handles {
1300 results.push(handle.await.unwrap());
1301 }
1302 results
1303 };
1304
1305 for result in &results {
1306 assert!(
1307 result.is_err(),
1308 "first wave: expected Expired, got Ok({})",
1309 result.as_ref().unwrap().as_str()
1310 );
1311 }
1312
1313 let counting2 = CountingState::new();
1315 let state2 = DelayedRefreshState {
1316 counting: counting2.clone(),
1317 delay: Duration::from_millis(50),
1318 };
1319 let (base_url2, stats2) = start_axum_server(delayed_refresh_handler, state2).await;
1320
1321 let strategy2 = Arc::new(auto_refresh_with_token(
1322 &dir,
1323 &base_url2,
1324 make_token("expired-token", 0, true),
1325 ));
1326
1327 let mut handles = Vec::with_capacity(CONCURRENCY);
1328 for _ in 0..CONCURRENCY {
1329 let s = Arc::clone(&strategy2);
1330 handles.push(tokio::spawn(async move { s.get_token().await.unwrap() }));
1331 }
1332
1333 let results: Vec<_> = {
1334 let mut results = Vec::with_capacity(handles.len());
1335 for handle in handles {
1336 results.push(handle.await.unwrap());
1337 }
1338 results
1339 };
1340
1341 for token in &results {
1342 assert_eq!(
1343 token.as_str(),
1344 "refreshed-token",
1345 "retry callers should receive refreshed token"
1346 );
1347 }
1348
1349 assert_eq!(stats2.total(), 1, "only one retry refresh should be made");
1350 }
1351 }
1352
1353 mod given_cancelled_refresh {
1354 use super::*;
1355
1356 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1360 async fn blocked_callers_recover_after_cancellation() {
1361 let counting = CountingState::new();
1362 let state = DelayedRefreshState {
1363 counting: counting.clone(),
1364 delay: Duration::from_secs(10), };
1366 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1367 let dir = tempfile::tempdir().unwrap();
1368 let strategy = Arc::new(auto_refresh_with_token(
1369 &dir,
1370 &base_url,
1371 make_token("expired-token", 0, true),
1372 ));
1373
1374 let s = Arc::clone(&strategy);
1376 let handle = tokio::spawn(async move { s.get_token().await });
1377 tokio::time::sleep(Duration::from_millis(100)).await;
1378
1379 handle.abort();
1381 let _ = handle.await;
1382
1383 let s = Arc::clone(&strategy);
1387 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1388
1389 assert!(
1390 result.is_ok(),
1391 "get_token() should not hang after cancelled blocking refresh"
1392 );
1393 }
1394
1395 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1399 async fn non_blocking_callers_recover_after_cancellation() {
1400 let counting = CountingState::new();
1401 let state = DelayedRefreshState {
1402 counting: counting.clone(),
1403 delay: Duration::from_secs(10), };
1405 let (base_url, _) = start_axum_server(delayed_refresh_handler, state).await;
1406 let dir = tempfile::tempdir().unwrap();
1407 let strategy = Arc::new(auto_refresh_with_token(
1409 &dir,
1410 &base_url,
1411 make_token("still-usable", 30, true),
1412 ));
1413
1414 let s = Arc::clone(&strategy);
1417 let handle = tokio::spawn(async move { s.get_token().await });
1418 tokio::time::sleep(Duration::from_millis(100)).await;
1419
1420 handle.abort();
1422 let _ = handle.await;
1423
1424 let s = Arc::clone(&strategy);
1427 let result = tokio::time::timeout(Duration::from_secs(2), s.get_token()).await;
1428
1429 assert!(
1430 result.is_ok(),
1431 "get_token() should not hang after cancelled non-blocking refresh"
1432 );
1433 let result = result.unwrap();
1434 assert!(
1435 result.is_ok(),
1436 "expected Ok with still-usable token, got: {:?}",
1437 result.unwrap_err()
1438 );
1439 }
1440 }
1441}