1use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Value the metadata server requires in the [`METADATA_FLAVOR`] request header.
pub(crate) const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header sent with every metadata server request.
pub(crate) const METADATA_FLAVOR: &str = "metadata-flavor";
/// Default metadata server root. It is used when neither the
/// `GCE_METADATA_HOST` environment variable nor `Builder::with_endpoint`
/// provides an endpoint.
pub(crate) const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path of the default service account, relative to the metadata server root.
/// The token endpoint is this path followed by `/token`.
pub(crate) const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable that overrides the metadata server host.
/// Its value is a host (optionally `host:port`) without a scheme; `http://` is
/// prepended when building the endpoint.
pub(crate) const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
/// Error message attached to token-fetch failures when the credentials were
/// created by Application Default Credentials and the endpoint was not
/// overridden. It explains how to configure local credentials.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
/// Credentials backed by the metadata server.
///
/// Each request for headers asks `token_provider` for the current token and
/// converts it into request headers. The `CachedTokenProvider` requirement
/// lives on the `CredentialsProvider` implementation rather than on the type
/// itself, following the Rust guideline of bounding impls, not structs.
#[derive(Debug)]
struct MDSCredentials<T> {
    /// Optional quota project, added to the generated headers when present.
    quota_project_id: Option<String>,
    /// Source of (cached) access tokens.
    token_provider: T,
}
119
/// Creates [`Credentials`] instances that obtain access tokens from the
/// Google Cloud metadata server (MDS).
///
/// By default the metadata server is reached at
/// `http://metadata.google.internal`. Use [`Builder::with_endpoint`] or the
/// `GCE_METADATA_HOST` environment variable to point at a different host.
#[derive(Debug, Default)]
pub struct Builder {
    // Endpoint set via `with_endpoint`; the `GCE_METADATA_HOST` environment
    // variable may take precedence over it when the provider is built.
    endpoint: Option<String>,
    // Quota project attached to the generated request headers.
    quota_project_id: Option<String>,
    // OAuth scopes requested from the metadata server, if any.
    scopes: Option<Vec<String>>,
    // Set by `from_adc`; selects the ADC guidance error message unless the
    // endpoint was overridden.
    created_by_adc: bool,
    // Retry, backoff and throttling configuration wrapped around the token provider.
    retry_builder: RetryTokenProviderBuilder,
}
139
140impl Builder {
141 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156 self.endpoint = Some(endpoint.into());
157 self
158 }
159
160 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169 self.quota_project_id = Some(quota_project_id.into());
170 self
171 }
172
173 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182 where
183 I: IntoIterator<Item = S>,
184 S: Into<String>,
185 {
186 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187 self
188 }
189
190 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206 self
207 }
208
209 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226 self
227 }
228
229 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251 self
252 }
253
254 pub(crate) fn from_adc() -> Self {
256 Self {
257 created_by_adc: true,
258 ..Default::default()
259 }
260 }
261
262 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263 let final_endpoint: String;
264 let endpoint_overridden: bool;
265
266 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268 final_endpoint = format!("http://{host_from_env}");
270 endpoint_overridden = true;
271 } else if let Some(builder_endpoint) = self.endpoint {
272 final_endpoint = builder_endpoint;
274 endpoint_overridden = true;
275 } else {
276 final_endpoint = METADATA_ROOT.to_string();
278 endpoint_overridden = false;
279 };
280
281 let tp = MDSAccessTokenProvider::builder()
282 .endpoint(final_endpoint)
283 .maybe_scopes(self.scopes)
284 .endpoint_overridden(endpoint_overridden)
285 .created_by_adc(self.created_by_adc)
286 .build();
287 self.retry_builder.build(tp)
288 }
289
290 pub fn build(self) -> BuildResult<Credentials> {
292 let mdsc = MDSCredentials {
293 quota_project_id: self.quota_project_id.clone(),
294 token_provider: TokenCache::new(self.build_token_provider()),
295 };
296 Ok(Credentials {
297 inner: Arc::new(mdsc),
298 })
299 }
300}
301
302#[async_trait::async_trait]
303impl<T> CredentialsProvider for MDSCredentials<T>
304where
305 T: CachedTokenProvider,
306{
307 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
308 let cached_token = self.token_provider.token(extensions).await?;
309 build_cacheable_headers(&cached_token, &self.quota_project_id)
310 }
311}
312
/// JSON body returned by the metadata server token endpoint.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    /// The access token itself.
    access_token: String,
    /// Token lifetime in seconds, counted from when the response is processed.
    /// It is omitted from the serialized form when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    /// Token type. It becomes the prefix of the `Authorization` header value.
    token_type: String,
}
320
/// Fetches access tokens from the metadata server token endpoint.
///
/// Instances are created through the `builder()` generated by `bon`.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    /// Scopes to request; `None` means no `scopes` query parameter is sent.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    /// Base URL of the metadata server, including the scheme.
    #[builder(into)]
    endpoint: String,
    /// Whether `endpoint` came from an override instead of the default root.
    endpoint_overridden: bool,
    /// Whether the credentials were created by Application Default Credentials.
    created_by_adc: bool,
}
330
331impl MDSAccessTokenProvider {
332 fn error_message(&self) -> &str {
340 if self.use_adc_message() {
341 MDS_NOT_FOUND_ERROR
342 } else {
343 "failed to fetch token"
344 }
345 }
346
347 fn use_adc_message(&self) -> bool {
348 self.created_by_adc && !self.endpoint_overridden
349 }
350}
351
352#[async_trait]
353impl TokenProvider for MDSAccessTokenProvider {
354 async fn token(&self) -> Result<Token> {
355 let client = Client::new();
356 let request = client
357 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
358 .header(
359 METADATA_FLAVOR,
360 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
361 );
362 let scopes = self.scopes.as_ref().map(|v| v.join(","));
365 let request = scopes
366 .into_iter()
367 .fold(request, |r, s| r.query(&[("scopes", s)]));
368
369 let response = request
374 .send()
375 .await
376 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
377 if !response.status().is_success() {
379 let err = crate::errors::from_http_response(response, self.error_message()).await;
380 return Err(err);
381 }
382 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
383 CredentialsError::from_source(!e.is_decode(), e)
387 })?;
388 let token = Token {
389 token: response.access_token,
390 token_type: response.token_type,
391 expires_at: response
392 .expires_in
393 .map(|d| Instant::now() + Duration::from_secs(d)),
394 metadata: None,
395 };
396 Ok(token)
397 }
398}
399
#[cfg(test)]
mod tests {
    // Tests that mutate `GCE_METADATA_HOST` are `#[serial]`. Every test that
    // builds a provider (and therefore reads that variable) is `#[parallel]`,
    // so the two groups never overlap.
    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use test_case::test_case;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok());
    }

    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag of the headers returned above must yield
        // `NotModified` instead of a fresh copy of the same headers.
        extensions.insert(entity_tag);

        let cached_headers = mdsc.headers(extensions).await?;
        match cached_headers {
            CacheableResource::New { .. } => {
                panic!("expecting NotModified after presenting the entity tag")
            }
            CacheableResource::NotModified => {}
        }
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    #[test]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(true)
            .endpoint_overridden(false)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(adc)
            .endpoint_overridden(overridden)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            // A metadata server is reachable in this environment; nothing to check.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let Err(e) = Builder::default().build_token_provider().token().await else {
            // A metadata server is reachable in this environment; nothing to check.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        // The expectation above allows exactly one request: the second call
        // must have been served from the cache.
        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    // `Builder::build` reads `GCE_METADATA_HOST`. Mark this test `#[parallel]` so
    // it never runs concurrently with the `#[serial]` tests that modify it.
    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }
}