1use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials, DEFAULT_UNIVERSE_DOMAIN};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Name of the header attached to every metadata-server request.
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Value sent with the [`METADATA_FLAVOR`] header.
const METADATA_FLAVOR_VALUE: &str = "Google";

/// Default root URL of the metadata server, used when no endpoint override is configured.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path (relative to the metadata root) of the default service account; the access
/// token is read from `{MDS_DEFAULT_URI}/token`.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable naming a metadata host that overrides the default root.
/// The builder prefixes its value with `http://`.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";

/// Message reported when an ADC-created credential fails to fetch a token from the
/// default (non-overridden) metadata endpoint.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114 T: CachedTokenProvider,
115{
116 quota_project_id: Option<String>,
117 universe_domain: Option<String>,
118 token_provider: T,
119}
120
121#[derive(Debug, Default)]
133pub struct Builder {
134 endpoint: Option<String>,
135 quota_project_id: Option<String>,
136 scopes: Option<Vec<String>>,
137 universe_domain: Option<String>,
138 created_by_adc: bool,
139 retry_builder: RetryTokenProviderBuilder,
140}
141
142impl Builder {
143 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
158 self.endpoint = Some(endpoint.into());
159 self
160 }
161
162 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
171 self.quota_project_id = Some(quota_project_id.into());
172 self
173 }
174
175 pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
182 self.universe_domain = Some(universe_domain.into());
183 self
184 }
185
186 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
195 where
196 I: IntoIterator<Item = S>,
197 S: Into<String>,
198 {
199 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
200 self
201 }
202
203 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
218 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
219 self
220 }
221
222 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
238 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
239 self
240 }
241
242 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
263 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
264 self
265 }
266
267 pub(crate) fn from_adc() -> Self {
269 Self {
270 created_by_adc: true,
271 ..Default::default()
272 }
273 }
274
275 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
276 let final_endpoint: String;
277 let endpoint_overridden: bool;
278
279 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
281 final_endpoint = format!("http://{host_from_env}");
283 endpoint_overridden = true;
284 } else if let Some(builder_endpoint) = self.endpoint {
285 final_endpoint = builder_endpoint;
287 endpoint_overridden = true;
288 } else {
289 final_endpoint = METADATA_ROOT.to_string();
291 endpoint_overridden = false;
292 };
293
294 let tp = MDSAccessTokenProvider::builder()
295 .endpoint(final_endpoint)
296 .maybe_scopes(self.scopes)
297 .endpoint_overridden(endpoint_overridden)
298 .created_by_adc(self.created_by_adc)
299 .build();
300 self.retry_builder.build(tp)
301 }
302
303 pub fn build(self) -> BuildResult<Credentials> {
305 let mdsc = MDSCredentials {
306 quota_project_id: self.quota_project_id.clone(),
307 universe_domain: self.universe_domain.clone(),
308 token_provider: TokenCache::new(self.build_token_provider()),
309 };
310 Ok(Credentials {
311 inner: Arc::new(mdsc),
312 })
313 }
314}
315
316#[async_trait::async_trait]
317impl<T> CredentialsProvider for MDSCredentials<T>
318where
319 T: CachedTokenProvider,
320{
321 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
322 let cached_token = self.token_provider.token(extensions).await?;
323 build_cacheable_headers(&cached_token, &self.quota_project_id)
324 }
325
326 async fn universe_domain(&self) -> Option<String> {
327 if self.universe_domain.is_some() {
328 return self.universe_domain.clone();
329 }
330 return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
331 }
332}
333
334#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
335struct MDSTokenResponse {
336 access_token: String,
337 #[serde(skip_serializing_if = "Option::is_none")]
338 expires_in: Option<u64>,
339 token_type: String,
340}
341
342#[derive(Debug, Clone, Default, Builder)]
343struct MDSAccessTokenProvider {
344 #[builder(into)]
345 scopes: Option<Vec<String>>,
346 #[builder(into)]
347 endpoint: String,
348 endpoint_overridden: bool,
349 created_by_adc: bool,
350}
351
352impl MDSAccessTokenProvider {
353 fn error_message(&self) -> &str {
361 if self.use_adc_message() {
362 MDS_NOT_FOUND_ERROR
363 } else {
364 "failed to fetch token"
365 }
366 }
367
368 fn use_adc_message(&self) -> bool {
369 self.created_by_adc && !self.endpoint_overridden
370 }
371}
372
373#[async_trait]
374impl TokenProvider for MDSAccessTokenProvider {
375 async fn token(&self) -> Result<Token> {
376 let client = Client::new();
377 let request = client
378 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
379 .header(
380 METADATA_FLAVOR,
381 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
382 );
383 let scopes = self.scopes.as_ref().map(|v| v.join(","));
386 let request = scopes
387 .into_iter()
388 .fold(request, |r, s| r.query(&[("scopes", s)]));
389
390 let response = request
395 .send()
396 .await
397 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
398 if !response.status().is_success() {
400 let err = crate::errors::from_http_response(response, self.error_message()).await;
401 return Err(err);
402 }
403 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
404 CredentialsError::from_source(!e.is_decode(), e)
408 })?;
409 let token = Token {
410 token: response.access_token,
411 token_type: response.token_type,
412 expires_at: response
413 .expires_in
414 .map(|d| Instant::now() + Duration::from_secs(d)),
415 metadata: None,
416 };
417 Ok(token)
418 }
419}
420
#[cfg(test)]
mod tests {
    //! Tests for metadata-server credentials. Tests that mutate `GCE_METADATA_HOST`
    //! are `#[serial]`; tests that read it (anything calling `build` or
    //! `build_token_provider`) are `#[parallel]` so they never overlap a mutation.

    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use test_case::test_case;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        // Once the retry policy is exhausted the error is no longer reported as transient.
        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok());
    }

    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        // Presenting the entity tag of the cached token must yield NotModified.
        let cached_headers = mdsc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting NotModified for an unchanged token"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    #[test]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(true)
            .endpoint_overridden(false)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(adc)
            .endpoint_overridden(overridden)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let e = Builder::default()
            .build_token_provider()
            .token()
            .await
            .err()
            .unwrap();

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        // The expectation allows exactly one request: the second call was served from cache.
        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    // `build()` reads GCE_METADATA_HOST, so keep this test from overlapping the
    // `#[serial]` tests that modify the environment.
    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    // `build()` reads GCE_METADATA_HOST, so keep this test from overlapping the
    // `#[serial]` tests that modify the environment.
    #[tokio::test]
    #[parallel]
    async fn get_custom_universe_domain_success() -> TestResult {
        let universe_domain = "test-universe";
        let universe_domain_response = Builder::default()
            .with_universe_domain(universe_domain)
            .build()?
            .universe_domain()
            .await
            .unwrap();
        assert_eq!(universe_domain_response, universe_domain);

        Ok(())
    }
}