1use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials, DEFAULT_UNIVERSE_DOMAIN};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Value of the `metadata-flavor` header sent with every token request.
const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header sent with every token request to the metadata server.
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Default root URL of the metadata server, used when no override is configured.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path of the default service account; the token endpoint is this path plus `/token`.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable naming the metadata server host. `http://` is prepended
/// to its value, and when set it takes precedence over both the default root and
/// any endpoint given to the builder.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
/// Error message used when credentials created by ADC fail to fetch a token from
/// the default (non-overridden) metadata endpoint. It explains how to set up
/// local credentials.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114 T: CachedTokenProvider,
115{
116 quota_project_id: Option<String>,
117 universe_domain: Option<String>,
118 token_provider: T,
119}
120
/// Builder for credentials backed by the metadata server (MDS).
///
/// All settings are optional. `Builder::default()` targets the default metadata
/// server root (unless `GCE_METADATA_HOST` is set), with no scopes, quota
/// project, or universe domain override.
#[derive(Debug, Default)]
pub struct Builder {
    // Endpoint set via `with_endpoint`; the `GCE_METADATA_HOST` env var takes precedence.
    endpoint: Option<String>,
    // Quota project attached to the generated request headers, if set.
    quota_project_id: Option<String>,
    // Joined with "," into the `scopes` query parameter of the token request.
    scopes: Option<Vec<String>>,
    // Reported by `universe_domain()`; `None` means the default universe domain.
    universe_domain: Option<String>,
    // True only when created via `from_adc`; enables the ADC setup hint in errors.
    created_by_adc: bool,
    // Retry / backoff / throttling configuration wrapped around the token provider.
    retry_builder: RetryTokenProviderBuilder,
}
141
142impl Builder {
143 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
158 self.endpoint = Some(endpoint.into());
159 self
160 }
161
162 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
171 self.quota_project_id = Some(quota_project_id.into());
172 self
173 }
174
175 pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
182 self.universe_domain = Some(universe_domain.into());
183 self
184 }
185
186 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
195 where
196 I: IntoIterator<Item = S>,
197 S: Into<String>,
198 {
199 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
200 self
201 }
202
203 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
218 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
219 self
220 }
221
222 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
238 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
239 self
240 }
241
242 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
263 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
264 self
265 }
266
267 pub(crate) fn from_adc() -> Self {
269 Self {
270 created_by_adc: true,
271 ..Default::default()
272 }
273 }
274
275 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
276 let final_endpoint: String;
277 let endpoint_overridden: bool;
278
279 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
281 final_endpoint = format!("http://{host_from_env}");
283 endpoint_overridden = true;
284 } else if let Some(builder_endpoint) = self.endpoint {
285 final_endpoint = builder_endpoint;
287 endpoint_overridden = true;
288 } else {
289 final_endpoint = METADATA_ROOT.to_string();
291 endpoint_overridden = false;
292 };
293
294 let tp = MDSAccessTokenProvider::builder()
295 .endpoint(final_endpoint)
296 .maybe_scopes(self.scopes)
297 .endpoint_overridden(endpoint_overridden)
298 .created_by_adc(self.created_by_adc)
299 .build();
300 self.retry_builder.build(tp)
301 }
302
303 pub fn build(self) -> BuildResult<Credentials> {
305 let mdsc = MDSCredentials {
306 quota_project_id: self.quota_project_id.clone(),
307 universe_domain: self.universe_domain.clone(),
308 token_provider: TokenCache::new(self.build_token_provider()),
309 };
310 Ok(Credentials {
311 inner: Arc::new(mdsc),
312 })
313 }
314}
315
316#[async_trait::async_trait]
317impl<T> CredentialsProvider for MDSCredentials<T>
318where
319 T: CachedTokenProvider,
320{
321 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
322 let cached_token = self.token_provider.token(extensions).await?;
323 build_cacheable_headers(&cached_token, &self.quota_project_id)
324 }
325
326 async fn universe_domain(&self) -> Option<String> {
327 if self.universe_domain.is_some() {
328 return self.universe_domain.clone();
329 }
330 return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
331 }
332}
333
/// JSON body returned by the metadata server's token endpoint.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    // The access token value.
    access_token: String,
    // Lifetime in seconds from now; converted into an absolute `expires_at` instant.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    // Token type, placed before the token in the `Authorization` header value.
    token_type: String,
}
341
/// Fetches access tokens for the default service account from the metadata server.
///
/// Constructed through the `bon`-generated `MDSAccessTokenProvider::builder()`.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    // Optional scopes, sent as a comma-separated `scopes` query parameter.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    // Metadata server root URL; the token path is appended to it.
    #[builder(into)]
    endpoint: String,
    // True when `endpoint` came from the env var or the builder, not the default.
    endpoint_overridden: bool,
    // True when created via ADC; together with `endpoint_overridden` this picks
    // the error message reported on failure.
    created_by_adc: bool,
}
351
352impl MDSAccessTokenProvider {
353 fn error_message(&self) -> &str {
361 if self.use_adc_message() {
362 MDS_NOT_FOUND_ERROR
363 } else {
364 "failed to fetch token"
365 }
366 }
367
368 fn use_adc_message(&self) -> bool {
369 self.created_by_adc && !self.endpoint_overridden
370 }
371}
372
373#[async_trait]
374impl TokenProvider for MDSAccessTokenProvider {
375 async fn token(&self) -> Result<Token> {
376 let client = Client::new();
377 let request = client
378 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
379 .header(
380 METADATA_FLAVOR,
381 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
382 );
383 let scopes = self.scopes.as_ref().map(|v| v.join(","));
386 let request = scopes
387 .into_iter()
388 .fold(request, |r, s| r.query(&[("scopes", s)]));
389
390 let response = request
395 .send()
396 .await
397 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
398 if !response.status().is_success() {
400 let err = crate::errors::from_http_response(response, self.error_message()).await;
401 return Err(err);
402 }
403 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
404 CredentialsError::from_source(!e.is_decode(), e)
408 })?;
409 let token = Token {
410 token: response.access_token,
411 token_type: response.token_type,
412 expires_at: response
413 .expires_in
414 .map(|d| Instant::now() + Duration::from_secs(d)),
415 metadata: None,
416 };
417 Ok(token)
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        get_headers_from_cache, get_mock_auth_retry_policy, get_mock_backoff_policy,
        get_mock_retry_throttler, get_token_from_headers, get_token_type_from_headers,
    };
    use crate::errors;
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use test_case::test_case;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    /// A 503 is retried until the retry policy is exhausted, and the final error is transient.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    /// A 401 is not retried: the server sees exactly one request.
    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    /// Two 503 responses followed by a 200 eventually yield the token.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    /// The default endpoint constants combine into valid URLs.
    #[test]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok());
    }

    /// Headers contain a sensitive `Authorization` value. Presenting the entity
    /// tag from the first call makes the second call return `NotModified`.
    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        let cached_headers = mdsc.headers(extensions).await?;
        // The entity tag from the first call is still current, so nothing changed.
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting NotModified headers"
        );
        Ok(())
    }

    /// Errors from the token provider propagate out of `headers`.
    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    /// ADC-created credentials on the default endpoint use the ADC guidance message.
    #[test]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(true)
            .endpoint_overridden(false)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    /// Any other combination of flags omits the ADC guidance message.
    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(adc)
            .endpoint_overridden(overridden)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    /// Without a reachable metadata server, ADC credentials fail transiently
    /// with the ADC guidance. Assumes the test host has no metadata server.
    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        assert!(err.is_transient(), "{err:?}");
        assert!(
            err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    /// Overriding the host via `GCE_METADATA_HOST` suppresses the ADC guidance.
    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        assert!(err.is_transient(), "{err:?}");
        assert!(
            !err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    /// Credentials built directly (not via ADC) never mention the ADC guidance.
    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let e = Builder::default()
            .build_token_provider()
            .token()
            .await
            .err()
            .unwrap();

        assert!(e.is_transient(), "{e:?}");
        assert!(
            !format!("{:?}", e.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    /// `GCE_METADATA_HOST` redirects token requests to the given host.
    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    /// A configured quota project appears as a non-sensitive header next to the token.
    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    /// Two `headers` calls with an unexpired token hit the server only once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        server.verify_and_clear();

        Ok(())
    }

    /// `expires_in` is converted to an absolute expiry relative to the fetch time.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    /// Without scopes, the request has no `scopes` parameter and still succeeds.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    /// End to end through `Credentials`: the token and token type reach the headers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    /// A 503 surfaces as a transient error whose source keeps the HTTP status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        assert!(err.is_transient());
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    /// A 401 surfaces as a non-transient error whose source keeps the HTTP status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    /// A body that is valid JSON but not a token response is a non-transient error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    /// Without an override, the default universe domain is reported.
    #[tokio::test]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    /// A configured universe domain is reported as-is.
    #[tokio::test]
    async fn get_custom_universe_domain_success() -> TestResult {
        let universe_domain = "test-universe";
        let universe_domain_response = Builder::default()
            .with_universe_domain(universe_domain)
            .build()?
            .universe_domain()
            .await
            .unwrap();
        assert_eq!(universe_domain_response, universe_domain);

        Ok(())
    }
}