1use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Value sent in the `metadata-flavor` header of every token request.
pub(crate) const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header attached to every token request.
pub(crate) const METADATA_FLAVOR: &str = "metadata-flavor";
/// Root URL used when neither the environment nor the builder overrides it.
pub(crate) const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path, appended to the endpoint, under which the `/token` resource is requested.
pub(crate) const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable whose value, when set, is used as the metadata host.
pub(crate) const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
/// Error text used when the provider was created by ADC and the endpoint was
/// not overridden.
// Each trailing `\` joins the next line; the space before it is preserved.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries";
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114 T: CachedTokenProvider,
115{
116 quota_project_id: Option<String>,
117 token_provider: T,
118}
119
120#[derive(Debug, Default)]
132pub struct Builder {
133 endpoint: Option<String>,
134 quota_project_id: Option<String>,
135 scopes: Option<Vec<String>>,
136 created_by_adc: bool,
137 retry_builder: RetryTokenProviderBuilder,
138}
139
140impl Builder {
141 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156 self.endpoint = Some(endpoint.into());
157 self
158 }
159
160 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169 self.quota_project_id = Some(quota_project_id.into());
170 self
171 }
172
173 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182 where
183 I: IntoIterator<Item = S>,
184 S: Into<String>,
185 {
186 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187 self
188 }
189
190 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206 self
207 }
208
209 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226 self
227 }
228
229 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251 self
252 }
253
254 pub(crate) fn from_adc() -> Self {
256 Self {
257 created_by_adc: true,
258 ..Default::default()
259 }
260 }
261
262 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263 let final_endpoint: String;
264 let endpoint_overridden: bool;
265
266 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268 final_endpoint = format!("http://{host_from_env}");
270 endpoint_overridden = true;
271 } else if let Some(builder_endpoint) = self.endpoint {
272 final_endpoint = builder_endpoint;
274 endpoint_overridden = true;
275 } else {
276 final_endpoint = METADATA_ROOT.to_string();
278 endpoint_overridden = false;
279 };
280
281 let tp = MDSAccessTokenProvider::builder()
282 .endpoint(final_endpoint)
283 .maybe_scopes(self.scopes)
284 .endpoint_overridden(endpoint_overridden)
285 .created_by_adc(self.created_by_adc)
286 .build();
287 self.retry_builder.build(tp)
288 }
289
290 pub fn build(self) -> BuildResult<Credentials> {
292 Ok(self.build_access_token_credentials()?.into())
293 }
294
295 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
312 let mdsc = MDSCredentials {
313 quota_project_id: self.quota_project_id.clone(),
314 token_provider: TokenCache::new(self.build_token_provider()),
315 };
316 Ok(AccessTokenCredentials {
317 inner: Arc::new(mdsc),
318 })
319 }
320}
321
322#[async_trait::async_trait]
323impl<T> CredentialsProvider for MDSCredentials<T>
324where
325 T: CachedTokenProvider,
326{
327 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
328 let cached_token = self.token_provider.token(extensions).await?;
329 build_cacheable_headers(&cached_token, &self.quota_project_id)
330 }
331}
332
333#[async_trait::async_trait]
334impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
335where
336 T: CachedTokenProvider,
337{
338 async fn access_token(&self) -> Result<AccessToken> {
339 let token = self.token_provider.token(Extensions::new()).await?;
340 token.into()
341 }
342}
343
344#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
345struct MDSTokenResponse {
346 access_token: String,
347 #[serde(skip_serializing_if = "Option::is_none")]
348 expires_in: Option<u64>,
349 token_type: String,
350}
351
352#[derive(Debug, Clone, Default, Builder)]
353struct MDSAccessTokenProvider {
354 #[builder(into)]
355 scopes: Option<Vec<String>>,
356 #[builder(into)]
357 endpoint: String,
358 endpoint_overridden: bool,
359 created_by_adc: bool,
360}
361
362impl MDSAccessTokenProvider {
363 fn error_message(&self) -> &str {
371 if self.use_adc_message() {
372 MDS_NOT_FOUND_ERROR
373 } else {
374 "failed to fetch token"
375 }
376 }
377
378 fn use_adc_message(&self) -> bool {
379 self.created_by_adc && !self.endpoint_overridden
380 }
381}
382
383#[async_trait]
384impl TokenProvider for MDSAccessTokenProvider {
385 async fn token(&self) -> Result<Token> {
386 let client = Client::new();
387 let request = client
388 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
389 .header(
390 METADATA_FLAVOR,
391 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
392 );
393 let scopes = self.scopes.as_ref().map(|v| v.join(","));
396 let request = scopes
397 .into_iter()
398 .fold(request, |r, s| r.query(&[("scopes", s)]));
399
400 let response = request
405 .send()
406 .await
407 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
408 if !response.status().is_success() {
410 let err = crate::errors::from_http_response(response, self.error_message()).await;
411 return Err(err);
412 }
413 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
414 CredentialsError::from_source(!e.is_decode(), e)
418 })?;
419 let token = Token {
420 token: response.access_token,
421 token_type: response.token_type,
422 expires_at: response
423 .expires_in
424 .map(|d| Instant::now() + Duration::from_secs(d)),
425 metadata: None,
426 };
427 Ok(token)
428 }
429}
430
431#[cfg(test)]
432mod tests {
433 use super::*;
434 use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
435 use crate::credentials::QUOTA_PROJECT_KEY;
436 use crate::credentials::tests::{
437 find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
438 get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
439 get_token_type_from_headers,
440 };
441 use crate::errors;
442 use crate::errors::CredentialsError;
443 use crate::token::tests::MockTokenProvider;
444 use http::HeaderValue;
445 use http::header::AUTHORIZATION;
446 use httptest::cycle;
447 use httptest::matchers::{all_of, contains, request, url_decoded};
448 use httptest::responders::{json_encoded, status_code};
449 use httptest::{Expectation, Server};
450 use reqwest::StatusCode;
451 use scoped_env::ScopedEnv;
452 use serial_test::{parallel, serial};
453 use std::error::Error;
454 use test_case::test_case;
455 use url::Url;
456
457 type TestResult = anyhow::Result<()>;
458
459 #[tokio::test]
460 #[parallel]
461 async fn test_mds_retries_on_transient_failures() -> TestResult {
462 let mut server = Server::run();
463 server.expect(
464 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
465 .times(3)
466 .respond_with(status_code(503)),
467 );
468
469 let provider = Builder::default()
470 .with_endpoint(format!("http://{}", server.addr()))
471 .with_retry_policy(get_mock_auth_retry_policy(3))
472 .with_backoff_policy(get_mock_backoff_policy())
473 .with_retry_throttler(get_mock_retry_throttler())
474 .build_token_provider();
475
476 let err = provider.token().await.unwrap_err();
477 assert!(!err.is_transient());
478 server.verify_and_clear();
479 Ok(())
480 }
481
482 #[tokio::test]
483 #[parallel]
484 async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
485 let mut server = Server::run();
486 server.expect(
487 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
488 .times(1)
489 .respond_with(status_code(401)),
490 );
491
492 let provider = Builder::default()
493 .with_endpoint(format!("http://{}", server.addr()))
494 .with_retry_policy(get_mock_auth_retry_policy(1))
495 .with_backoff_policy(get_mock_backoff_policy())
496 .with_retry_throttler(get_mock_retry_throttler())
497 .build_token_provider();
498
499 let err = provider.token().await.unwrap_err();
500 assert!(!err.is_transient());
501 server.verify_and_clear();
502 Ok(())
503 }
504
505 #[tokio::test]
506 #[parallel]
507 async fn test_mds_retries_for_success() -> TestResult {
508 let mut server = Server::run();
509 let response = MDSTokenResponse {
510 access_token: "test-access-token".to_string(),
511 expires_in: Some(3600),
512 token_type: "test-token-type".to_string(),
513 };
514
515 server.expect(
516 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
517 .times(3)
518 .respond_with(cycle![
519 status_code(503).body("try-again"),
520 status_code(503).body("try-again"),
521 status_code(200)
522 .append_header("Content-Type", "application/json")
523 .body(serde_json::to_string(&response).unwrap()),
524 ]),
525 );
526
527 let provider = Builder::default()
528 .with_endpoint(format!("http://{}", server.addr()))
529 .with_retry_policy(get_mock_auth_retry_policy(3))
530 .with_backoff_policy(get_mock_backoff_policy())
531 .with_retry_throttler(get_mock_retry_throttler())
532 .build_token_provider();
533
534 let token = provider.token().await?;
535 assert_eq!(token.token, "test-access-token");
536
537 server.verify_and_clear();
538 Ok(())
539 }
540
#[test]
fn validate_default_endpoint_urls() {
    // Both the service account root and its token resource must parse as URLs.
    for suffix in ["", "/token"] {
        let candidate = format!("{METADATA_ROOT}{MDS_DEFAULT_URI}{suffix}");
        assert!(Url::parse(&candidate).is_ok());
    }
}
549
550 #[tokio::test]
551 async fn headers_success() -> TestResult {
552 let token = Token {
553 token: "test-token".to_string(),
554 token_type: "Bearer".to_string(),
555 expires_at: None,
556 metadata: None,
557 };
558
559 let mut mock = MockTokenProvider::new();
560 mock.expect_token().times(1).return_once(|| Ok(token));
561
562 let mdsc = MDSCredentials {
563 quota_project_id: None,
564 token_provider: TokenCache::new(mock),
565 };
566
567 let mut extensions = Extensions::new();
568 let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
569 let (headers, entity_tag) = match cached_headers {
570 CacheableResource::New { entity_tag, data } => (data, entity_tag),
571 CacheableResource::NotModified => unreachable!("expecting new headers"),
572 };
573 let token = headers.get(AUTHORIZATION).unwrap();
574 assert_eq!(headers.len(), 1, "{headers:?}");
575 assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
576 assert!(token.is_sensitive());
577
578 extensions.insert(entity_tag);
579
580 let cached_headers = mdsc.headers(extensions).await?;
581
582 match cached_headers {
583 CacheableResource::New { .. } => unreachable!("expecting new headers"),
584 CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
585 };
586 Ok(())
587 }
588
589 #[tokio::test]
590 async fn access_token_success() -> TestResult {
591 let server = Server::run();
592 let response = MDSTokenResponse {
593 access_token: "test-access-token".to_string(),
594 expires_in: Some(3600),
595 token_type: "Bearer".to_string(),
596 };
597 server.expect(
598 Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
599 .respond_with(json_encoded(response)),
600 );
601
602 let creds = Builder::default()
603 .with_endpoint(format!("http://{}", server.addr()))
604 .build_access_token_credentials()
605 .unwrap();
606
607 let access_token = creds.access_token().await.unwrap();
608 assert_eq!(access_token.token, "test-access-token");
609
610 Ok(())
611 }
612
613 #[tokio::test]
614 async fn headers_failure() {
615 let mut mock = MockTokenProvider::new();
616 mock.expect_token()
617 .times(1)
618 .return_once(|| Err(errors::non_retryable_from_str("fail")));
619
620 let mdsc = MDSCredentials {
621 quota_project_id: None,
622 token_provider: TokenCache::new(mock),
623 };
624 assert!(mdsc.headers(Extensions::new()).await.is_err());
625 }
626
#[test]
fn error_message_with_adc() {
    // Created by ADC with the default endpoint: the long setup hint is expected.
    let provider = MDSAccessTokenProvider::builder()
        .endpoint("http://127.0.0.1")
        .created_by_adc(true)
        .endpoint_overridden(false)
        .build();

    let got = provider.error_message();
    assert!(got.contains(MDS_NOT_FOUND_ERROR), "{got}, {provider:?}");
}
639
640 #[test_case(false, false)]
641 #[test_case(false, true)]
642 #[test_case(true, true)]
643 fn error_message_without_adc(adc: bool, overridden: bool) {
644 let provider = MDSAccessTokenProvider::builder()
645 .endpoint("http://127.0.0.1")
646 .created_by_adc(adc)
647 .endpoint_overridden(overridden)
648 .build();
649
650 let not_want = MDS_NOT_FOUND_ERROR;
651 let got = provider.error_message();
652 assert!(!got.contains(not_want), "{got}, {provider:?}");
653 }
654
655 #[tokio::test]
656 #[serial]
657 async fn adc_no_mds() -> TestResult {
658 let Err(err) = Builder::from_adc().build_token_provider().token().await else {
659 return Ok(());
661 };
662
663 let original_err = find_source_error::<CredentialsError>(&err).unwrap();
664 assert!(
665 original_err.to_string().contains("application-default"),
666 "display={err}, debug={err:?}"
667 );
668
669 Ok(())
670 }
671
672 #[tokio::test]
673 #[serial]
674 async fn adc_overridden_mds() -> TestResult {
675 let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
676
677 let err = Builder::from_adc()
678 .build_token_provider()
679 .token()
680 .await
681 .unwrap_err();
682
683 let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
684
685 let original_err = find_source_error::<CredentialsError>(&err).unwrap();
686 assert!(original_err.is_transient());
687 assert!(
688 !original_err.to_string().contains("application-default"),
689 "display={err}, debug={err:?}"
690 );
691 let source = find_source_error::<reqwest::Error>(&err);
692 assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
693
694 Ok(())
695 }
696
697 #[tokio::test]
698 #[serial]
699 async fn builder_no_mds() -> TestResult {
700 let Err(e) = Builder::default().build_token_provider().token().await else {
701 return Ok(());
703 };
704
705 let original_err = find_source_error::<CredentialsError>(&e).unwrap();
706 assert!(
707 !format!("{:?}", original_err.source()).contains("application-default"),
708 "{e:?}"
709 );
710
711 Ok(())
712 }
713
714 #[tokio::test]
715 #[serial]
716 async fn test_gce_metadata_host_env_var() -> TestResult {
717 let server = Server::run();
718 let scopes = ["scope1", "scope2"];
719 let response = MDSTokenResponse {
720 access_token: "test-access-token".to_string(),
721 expires_in: Some(3600),
722 token_type: "test-token-type".to_string(),
723 };
724 server.expect(
725 Expectation::matching(all_of![
726 request::path(format!("{MDS_DEFAULT_URI}/token")),
727 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
728 ])
729 .respond_with(json_encoded(response)),
730 );
731
732 let addr = server.addr().to_string();
733 let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
734 let mdsc = Builder::default()
735 .with_scopes(["scope1", "scope2"])
736 .build()
737 .unwrap();
738 let headers = mdsc.headers(Extensions::new()).await.unwrap();
739 let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
740
741 assert_eq!(
742 get_token_from_headers(headers).unwrap(),
743 "test-access-token"
744 );
745 Ok(())
746 }
747
748 #[tokio::test]
749 #[parallel]
750 async fn headers_success_with_quota_project() -> TestResult {
751 let server = Server::run();
752 let scopes = ["scope1", "scope2"];
753 let response = MDSTokenResponse {
754 access_token: "test-access-token".to_string(),
755 expires_in: Some(3600),
756 token_type: "test-token-type".to_string(),
757 };
758 server.expect(
759 Expectation::matching(all_of![
760 request::path(format!("{MDS_DEFAULT_URI}/token")),
761 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
762 ])
763 .respond_with(json_encoded(response)),
764 );
765
766 let mdsc = Builder::default()
767 .with_scopes(["scope1", "scope2"])
768 .with_endpoint(format!("http://{}", server.addr()))
769 .with_quota_project_id("test-project")
770 .build()?;
771
772 let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
773 let token = headers.get(AUTHORIZATION).unwrap();
774 let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
775
776 assert_eq!(headers.len(), 2, "{headers:?}");
777 assert_eq!(
778 token,
779 HeaderValue::from_static("test-token-type test-access-token")
780 );
781 assert!(token.is_sensitive());
782 assert_eq!(quota_project, HeaderValue::from_static("test-project"));
783 assert!(!quota_project.is_sensitive());
784
785 Ok(())
786 }
787
788 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
789 #[parallel]
790 async fn token_caching() -> TestResult {
791 let mut server = Server::run();
792 let scopes = vec!["scope1".to_string()];
793 let response = MDSTokenResponse {
794 access_token: "test-access-token".to_string(),
795 expires_in: Some(3600),
796 token_type: "test-token-type".to_string(),
797 };
798 server.expect(
799 Expectation::matching(all_of![
800 request::path(format!("{MDS_DEFAULT_URI}/token")),
801 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
802 ])
803 .times(1)
804 .respond_with(json_encoded(response)),
805 );
806
807 let mdsc = Builder::default()
808 .with_scopes(scopes)
809 .with_endpoint(format!("http://{}", server.addr()))
810 .build()?;
811 let headers = mdsc.headers(Extensions::new()).await?;
812 assert_eq!(
813 get_token_from_headers(headers).unwrap(),
814 "test-access-token"
815 );
816 let headers = mdsc.headers(Extensions::new()).await?;
817 assert_eq!(
818 get_token_from_headers(headers).unwrap(),
819 "test-access-token"
820 );
821
822 server.verify_and_clear();
824
825 Ok(())
826 }
827
828 #[tokio::test(start_paused = true)]
829 #[parallel]
830 async fn token_provider_full() -> TestResult {
831 let server = Server::run();
832 let scopes = vec!["scope1".to_string()];
833 let response = MDSTokenResponse {
834 access_token: "test-access-token".to_string(),
835 expires_in: Some(3600),
836 token_type: "test-token-type".to_string(),
837 };
838 server.expect(
839 Expectation::matching(all_of![
840 request::path(format!("{MDS_DEFAULT_URI}/token")),
841 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
842 ])
843 .respond_with(json_encoded(response)),
844 );
845
846 let token = Builder::default()
847 .with_endpoint(format!("http://{}", server.addr()))
848 .with_scopes(scopes)
849 .build_token_provider()
850 .token()
851 .await?;
852
853 let now = tokio::time::Instant::now();
854 assert_eq!(token.token, "test-access-token");
855 assert_eq!(token.token_type, "test-token-type");
856 assert!(
857 token
858 .expires_at
859 .is_some_and(|d| d >= now + Duration::from_secs(3600))
860 );
861
862 Ok(())
863 }
864
865 #[tokio::test(start_paused = true)]
866 #[parallel]
867 async fn token_provider_full_no_scopes() -> TestResult {
868 let server = Server::run();
869 let response = MDSTokenResponse {
870 access_token: "test-access-token".to_string(),
871 expires_in: Some(3600),
872 token_type: "test-token-type".to_string(),
873 };
874 server.expect(
875 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
876 .respond_with(json_encoded(response)),
877 );
878
879 let token = Builder::default()
880 .with_endpoint(format!("http://{}", server.addr()))
881 .build_token_provider()
882 .token()
883 .await?;
884
885 let now = Instant::now();
886 assert_eq!(token.token, "test-access-token");
887 assert_eq!(token.token_type, "test-token-type");
888 assert!(
889 token
890 .expires_at
891 .is_some_and(|d| d == now + Duration::from_secs(3600))
892 );
893
894 Ok(())
895 }
896
897 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
898 #[parallel]
899 async fn credential_provider_full() -> TestResult {
900 let server = Server::run();
901 let scopes = vec!["scope1".to_string()];
902 let response = MDSTokenResponse {
903 access_token: "test-access-token".to_string(),
904 expires_in: None,
905 token_type: "test-token-type".to_string(),
906 };
907 server.expect(
908 Expectation::matching(all_of![
909 request::path(format!("{MDS_DEFAULT_URI}/token")),
910 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
911 ])
912 .respond_with(json_encoded(response)),
913 );
914
915 let mdsc = Builder::default()
916 .with_endpoint(format!("http://{}", server.addr()))
917 .with_scopes(scopes)
918 .build()?;
919 let headers = mdsc.headers(Extensions::new()).await?;
920 assert_eq!(
921 get_token_from_headers(headers.clone()).unwrap(),
922 "test-access-token"
923 );
924 assert_eq!(
925 get_token_type_from_headers(headers).unwrap(),
926 "test-token-type"
927 );
928
929 Ok(())
930 }
931
932 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
933 #[parallel]
934 async fn credentials_headers_retryable_error() -> TestResult {
935 let server = Server::run();
936 let scopes = vec!["scope1".to_string()];
937 server.expect(
938 Expectation::matching(all_of![
939 request::path(format!("{MDS_DEFAULT_URI}/token")),
940 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
941 ])
942 .respond_with(status_code(503)),
943 );
944
945 let mdsc = Builder::default()
946 .with_endpoint(format!("http://{}", server.addr()))
947 .with_scopes(scopes)
948 .build()?;
949 let err = mdsc.headers(Extensions::new()).await.unwrap_err();
950 let original_err = find_source_error::<CredentialsError>(&err).unwrap();
951 assert!(original_err.is_transient());
952 let source = find_source_error::<reqwest::Error>(&err);
953 assert!(
954 matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
955 "{err:?}"
956 );
957
958 Ok(())
959 }
960
961 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
962 #[parallel]
963 async fn credentials_headers_nonretryable_error() -> TestResult {
964 let server = Server::run();
965 let scopes = vec!["scope1".to_string()];
966 server.expect(
967 Expectation::matching(all_of![
968 request::path(format!("{MDS_DEFAULT_URI}/token")),
969 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
970 ])
971 .respond_with(status_code(401)),
972 );
973
974 let mdsc = Builder::default()
975 .with_endpoint(format!("http://{}", server.addr()))
976 .with_scopes(scopes)
977 .build()?;
978
979 let err = mdsc.headers(Extensions::new()).await.unwrap_err();
980 let original_err = find_source_error::<CredentialsError>(&err).unwrap();
981 assert!(!original_err.is_transient());
982 let source = find_source_error::<reqwest::Error>(&err);
983 assert!(
984 matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
985 "{err:?}"
986 );
987
988 Ok(())
989 }
990
991 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
992 #[parallel]
993 async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
994 let server = Server::run();
995 let scopes = vec!["scope1".to_string()];
996 server.expect(
997 Expectation::matching(all_of![
998 request::path(format!("{MDS_DEFAULT_URI}/token")),
999 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1000 ])
1001 .respond_with(json_encoded("bad json")),
1002 );
1003
1004 let mdsc = Builder::default()
1005 .with_endpoint(format!("http://{}", server.addr()))
1006 .with_scopes(scopes)
1007 .build()?;
1008
1009 let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1010 assert!(!e.is_transient());
1011
1012 Ok(())
1013 }
1014
1015 #[tokio::test]
1016 async fn get_default_universe_domain_success() -> TestResult {
1017 let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
1018 assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
1019 Ok(())
1020 }
1021}