1use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::headers_util::build_cacheable_headers;
80use crate::mds::client::Client as MDSClient;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use gax::backoff_policy::BackoffPolicyArg;
87use gax::error::CredentialsError;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap};
91use std::default::Default;
92use std::sync::Arc;
93
/// Error context used when a token cannot be fetched from the metadata server.
///
/// Only used when the credentials were created by Application Default
/// Credentials (see `Builder::from_adc`) and the metadata client targets its
/// default endpoint. The text guides users toward configuring local
/// credentials with `gcloud`.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
102
103#[derive(Debug)]
104struct MDSCredentials<T>
105where
106 T: CachedTokenProvider,
107{
108 quota_project_id: Option<String>,
109 token_provider: T,
110}
111
#[derive(Debug, Default)]
/// Builder for credentials that fetch access tokens from the metadata server
/// (MDS).
///
/// All settings are optional; `Builder::default()` targets the metadata
/// client's default endpoint with no explicit scopes and no quota project.
pub struct Builder {
    // Metadata server endpoint override; `None` lets the metadata client pick
    // its default endpoint.
    endpoint: Option<String>,
    // Quota project added to the generated request headers, when set.
    quota_project_id: Option<String>,
    // Scopes forwarded with each token request; `None` sends no scopes.
    scopes: Option<Vec<String>>,
    // Set by `from_adc`; together with a default endpoint it selects the ADC
    // guidance message (`MDS_NOT_FOUND_ERROR`) for token fetch failures.
    created_by_adc: bool,
    // Retry, backoff, and throttling configuration applied to token fetches.
    retry_builder: RetryTokenProviderBuilder,
}
131
132impl Builder {
133 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
148 self.endpoint = Some(endpoint.into());
149 self
150 }
151
152 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
161 self.quota_project_id = Some(quota_project_id.into());
162 self
163 }
164
165 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
174 where
175 I: IntoIterator<Item = S>,
176 S: Into<String>,
177 {
178 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
179 self
180 }
181
182 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
197 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
198 self
199 }
200
201 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
217 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
218 self
219 }
220
221 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
242 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
243 self
244 }
245
246 pub(crate) fn from_adc() -> Self {
248 Self {
249 created_by_adc: true,
250 ..Default::default()
251 }
252 }
253
254 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
255 let tp = MDSAccessTokenProvider::builder()
256 .endpoint(self.endpoint)
257 .maybe_scopes(self.scopes)
258 .created_by_adc(self.created_by_adc)
259 .build();
260 self.retry_builder.build(tp)
261 }
262
263 pub fn build(self) -> BuildResult<Credentials> {
265 Ok(self.build_access_token_credentials()?.into())
266 }
267
268 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
285 let mdsc = MDSCredentials {
286 quota_project_id: self.quota_project_id.clone(),
287 token_provider: TokenCache::new(self.build_token_provider()),
288 };
289 Ok(AccessTokenCredentials {
290 inner: Arc::new(mdsc),
291 })
292 }
293
294 pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
312 self.build_signer_with_iam_endpoint_override(None)
313 }
314
315 fn build_signer_with_iam_endpoint_override(
317 self,
318 iam_endpoint: Option<String>,
319 ) -> BuildResult<crate::signer::Signer> {
320 let client = MDSClient::new(self.endpoint.clone());
321 let credentials = self.build()?;
322 let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
323 let signing_provider = iam_endpoint
324 .iter()
325 .fold(signing_provider, |signing_provider, endpoint| {
326 signing_provider.with_iam_endpoint_override(endpoint)
327 });
328 Ok(crate::signer::Signer {
329 inner: Arc::new(signing_provider),
330 })
331 }
332}
333
334#[async_trait::async_trait]
335impl<T> CredentialsProvider for MDSCredentials<T>
336where
337 T: CachedTokenProvider,
338{
339 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
340 let cached_token = self.token_provider.token(extensions).await?;
341 build_cacheable_headers(&cached_token, &self.quota_project_id)
342 }
343}
344
345#[async_trait::async_trait]
346impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
347where
348 T: CachedTokenProvider,
349{
350 async fn access_token(&self) -> Result<AccessToken> {
351 let token = self.token_provider.token(Extensions::new()).await?;
352 token.into()
353 }
354}
355
/// Builder for `MDSAccessTokenProvider`; the derived default has no scopes,
/// no endpoint override, and `created_by_adc == false`.
#[derive(Debug, Default)]
struct MDSAccessTokenProviderBuilder {
    // Scopes forwarded with each token request, if any.
    scopes: Option<Vec<String>>,
    // Endpoint passed to `MDSClient::new`; `None` selects the client default.
    endpoint: Option<String>,
    // Whether the provider is being created on behalf of ADC.
    created_by_adc: bool,
}
362
363impl MDSAccessTokenProviderBuilder {
364 fn build(self) -> MDSAccessTokenProvider {
365 MDSAccessTokenProvider {
366 client: MDSClient::new(self.endpoint),
367 scopes: self.scopes,
368 created_by_adc: self.created_by_adc,
369 }
370 }
371
372 fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
373 self.scopes = v;
374 self
375 }
376
377 fn endpoint<T>(mut self, v: Option<T>) -> Self
378 where
379 T: Into<String>,
380 {
381 self.endpoint = v.map(Into::into);
382 self
383 }
384
385 fn created_by_adc(mut self, v: bool) -> Self {
386 self.created_by_adc = v;
387 self
388 }
389}
390
/// Token provider that requests access tokens from the metadata server on
/// every call (no caching at this layer).
#[derive(Debug, Clone)]
struct MDSAccessTokenProvider {
    // Scopes forwarded with each token request, if any.
    scopes: Option<Vec<String>>,
    // Client used to reach the metadata server.
    client: MDSClient,
    // Whether the provider was created on behalf of ADC; affects error text.
    created_by_adc: bool,
}
397
398impl MDSAccessTokenProvider {
399 fn builder() -> MDSAccessTokenProviderBuilder {
400 MDSAccessTokenProviderBuilder::default()
401 }
402
403 fn error_message(&self) -> &str {
411 if self.use_adc_message() {
412 MDS_NOT_FOUND_ERROR
413 } else {
414 "failed to fetch token"
415 }
416 }
417
418 fn use_adc_message(&self) -> bool {
419 self.created_by_adc && self.client.is_default_endpoint
420 }
421}
422
423#[async_trait]
424impl TokenProvider for MDSAccessTokenProvider {
425 async fn token(&self) -> Result<Token> {
426 self.client
427 .access_token(self.scopes.clone())
428 .await
429 .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::mds::client::MDSTokenResponse;
    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use std::time::Duration;
    use test_case::test_case;
    use tokio::time::Instant;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    // A transient (503) failure is retried until the retry policy gives up;
    // the error reported after exhausting the retries is not transient.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // A non-transient (401) failure is not retried: exactly one request.
    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    // Two transient failures followed by a success yield the token.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    // The default metadata URLs must be well-formed.
    #[test]
    #[parallel]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok());
    }

    // Headers are built from the cached token; presenting the returned entity
    // tag on a second call yields `NotModified`.
    #[tokio::test]
    #[parallel]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        // The entity tag from the first call still matches the cached token,
        // so no new headers should be produced.
        let cached_headers = mdsc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting NotModified headers on the second call"
        );
        Ok(())
    }

    // Access token credentials return the token served by the metadata server.
    #[tokio::test]
    #[parallel]
    async fn access_token_success() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "Bearer".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(response)),
        );

        let creds = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_access_token_credentials()
            .unwrap();

        let access_token = creds.access_token().await.unwrap();
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    // Token provider errors propagate out of `headers()`.
    #[tokio::test]
    #[parallel]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    // ADC-created providers on the default endpoint use the ADC guidance.
    #[test]
    #[parallel]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .created_by_adc(true)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    // Any other combination uses the generic message.
    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let endpoint = if overridden {
            Some("http://127.0.0.1")
        } else {
            None
        };
        let provider = MDSAccessTokenProvider::builder()
            .endpoint(endpoint)
            .created_by_adc(adc)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    // Without a reachable metadata server, ADC-created credentials report the
    // ADC guidance.
    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            // A token was fetched, so a metadata server is available and
            // there is no error to inspect.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    // With `GCE_METADATA_HOST` overridden, errors do not carry the ADC
    // guidance and surface the underlying transport error.
    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    // Credentials not created by ADC never report the ADC guidance.
    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let Err(e) = Builder::default().build_token_provider().token().await else {
            // A token was fetched, so a metadata server is available and
            // there is no error to inspect.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    // The `GCE_METADATA_HOST` environment variable selects the metadata host.
    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    // The quota project is added as a non-sensitive header next to the token.
    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    // Two header requests trigger a single token fetch.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        // Verifies the `.times(1)` expectation: the second call was cached.
        server.verify_and_clear();

        Ok(())
    }

    // The token provider maps the response fields and expiry.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    // Same as above, without scopes.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    // Credentials expose the token and its type through the headers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    // A 503 surfaces as a transient error whose source carries the status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    // A 401 surfaces as a non-transient error whose source carries the status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    // A response that cannot be parsed is not retryable.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    // MDS credentials report the default universe domain.
    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    // The signer discovers the client email from the metadata server and signs
    // through the (overridden) IAM endpoint using the MDS access token.
    #[tokio::test]
    #[parallel]
    async fn get_mds_signer() -> TestResult {
        use base64::{Engine, prelude::BASE64_STANDARD};
        use serde_json::json;

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "POST",
                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
            }))),
        );

        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let signer = Builder::default()
            .with_endpoint(&endpoint)
            .build_signer_with_iam_endpoint_override(Some(endpoint))?;

        let client_email = signer.client_email().await?;
        assert_eq!(client_email, "test-client-email");

        let signature = signer.sign(b"test").await?;
        assert_eq!(signature.as_ref(), b"signed_blob");

        Ok(())
    }
}