1use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::headers_util::AuthHeadersBuilder;
80use crate::mds::client::Client as MDSClient;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use google_cloud_gax::backoff_policy::BackoffPolicyArg;
87use google_cloud_gax::error::CredentialsError;
88use google_cloud_gax::retry_policy::RetryPolicyArg;
89use google_cloud_gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap};
91use std::default::Default;
92use std::sync::Arc;
93
/// Guidance attached to token fetch failures from credentials created through
/// Application Default Credentials that talk to the default metadata server
/// endpoint (see `MDSAccessTokenProvider::error_message`).
///
/// Each `\` line continuation drops the newline and the next line's leading
/// indentation, so the trailing space before every `\` is the only separator.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at \
    https://cloud.google.com/docs/authentication/client-libraries";
102
103#[derive(Debug)]
104struct MDSCredentials<T>
105where
106 T: CachedTokenProvider,
107{
108 quota_project_id: Option<String>,
109 token_provider: T,
110}
111
/// Builder for credentials that fetch access tokens from the Google Cloud
/// metadata server (MDS).
///
/// Start from `Builder::default()`, configure it with the `with_*` methods,
/// then call `build`, `build_access_token_credentials`, or `build_signer`.
#[derive(Debug, Default)]
pub struct Builder {
    /// Metadata server endpoint override; `None` means no override is passed to the client.
    endpoint: Option<String>,
    /// Quota project added to the generated headers, if any.
    quota_project_id: Option<String>,
    /// Scopes forwarded with each token request, if any.
    scopes: Option<Vec<String>>,
    /// Set by `from_adc`; enables the ADC-specific error message.
    created_by_adc: bool,
    /// Accumulates the retry policy, backoff policy, and retry throttler.
    retry_builder: RetryTokenProviderBuilder,
}
131
132impl Builder {
133 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
148 self.endpoint = Some(endpoint.into());
149 self
150 }
151
152 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
161 self.quota_project_id = Some(quota_project_id.into());
162 self
163 }
164
165 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
174 where
175 I: IntoIterator<Item = S>,
176 S: Into<String>,
177 {
178 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
179 self
180 }
181
182 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
197 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
198 self
199 }
200
201 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
217 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
218 self
219 }
220
221 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
242 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
243 self
244 }
245
246 pub(crate) fn from_adc() -> Self {
248 Self {
249 created_by_adc: true,
250 ..Default::default()
251 }
252 }
253
254 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
255 let tp = MDSAccessTokenProvider::builder()
256 .endpoint(self.endpoint)
257 .maybe_scopes(self.scopes)
258 .created_by_adc(self.created_by_adc)
259 .build();
260 self.retry_builder.build(tp)
261 }
262
263 pub fn build(self) -> BuildResult<Credentials> {
265 Ok(self.build_access_token_credentials()?.into())
266 }
267
268 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
285 let mdsc = MDSCredentials {
286 quota_project_id: self.quota_project_id.clone(),
287 token_provider: TokenCache::new(self.build_token_provider()),
288 };
289 Ok(AccessTokenCredentials {
290 inner: Arc::new(mdsc),
291 })
292 }
293
294 pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
312 self.build_signer_with_iam_endpoint_override(None)
313 }
314
315 fn build_signer_with_iam_endpoint_override(
317 self,
318 iam_endpoint: Option<String>,
319 ) -> BuildResult<crate::signer::Signer> {
320 let client = MDSClient::new(self.endpoint.clone());
321 let credentials = self.build()?;
322 let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
323 let signing_provider = iam_endpoint
324 .iter()
325 .fold(signing_provider, |signing_provider, endpoint| {
326 signing_provider.with_iam_endpoint_override(endpoint)
327 });
328 Ok(crate::signer::Signer {
329 inner: Arc::new(signing_provider),
330 })
331 }
332}
333
334#[async_trait::async_trait]
335impl<T> CredentialsProvider for MDSCredentials<T>
336where
337 T: CachedTokenProvider,
338{
339 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
340 let token = self.token_provider.token(extensions).await?;
341
342 AuthHeadersBuilder::new(&token)
343 .maybe_quota_project_id(self.quota_project_id.as_deref())
344 .build()
345 }
346}
347
348#[async_trait::async_trait]
349impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
350where
351 T: CachedTokenProvider,
352{
353 async fn access_token(&self) -> Result<AccessToken> {
354 let token = self.token_provider.token(Extensions::new()).await?;
355 token.into()
356 }
357}
358
359#[derive(Debug, Default)]
360struct MDSAccessTokenProviderBuilder {
361 scopes: Option<Vec<String>>,
362 endpoint: Option<String>,
363 created_by_adc: bool,
364}
365
366impl MDSAccessTokenProviderBuilder {
367 fn build(self) -> MDSAccessTokenProvider {
368 MDSAccessTokenProvider {
369 client: MDSClient::new(self.endpoint),
370 scopes: self.scopes,
371 created_by_adc: self.created_by_adc,
372 }
373 }
374
375 fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
376 self.scopes = v;
377 self
378 }
379
380 fn endpoint<T>(mut self, v: Option<T>) -> Self
381 where
382 T: Into<String>,
383 {
384 self.endpoint = v.map(Into::into);
385 self
386 }
387
388 fn created_by_adc(mut self, v: bool) -> Self {
389 self.created_by_adc = v;
390 self
391 }
392}
393
394#[derive(Debug, Clone)]
395struct MDSAccessTokenProvider {
396 scopes: Option<Vec<String>>,
397 client: MDSClient,
398 created_by_adc: bool,
399}
400
401impl MDSAccessTokenProvider {
402 fn builder() -> MDSAccessTokenProviderBuilder {
403 MDSAccessTokenProviderBuilder::default()
404 }
405
406 fn error_message(&self) -> &str {
414 if self.use_adc_message() {
415 MDS_NOT_FOUND_ERROR
416 } else {
417 "failed to fetch token"
418 }
419 }
420
421 fn use_adc_message(&self) -> bool {
422 self.created_by_adc && self.client.is_default_endpoint
423 }
424}
425
426#[async_trait]
427impl TokenProvider for MDSAccessTokenProvider {
428 async fn token(&self) -> Result<Token> {
429 self.client
430 .access_token(self.scopes.clone())
431 .await
432 .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::mds::client::MDSTokenResponse;
    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use std::time::Duration;
    use test_case::test_case;
    use tokio::time::Instant;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    #[parallel]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(
            default_endpoint_address.is_ok(),
            "{default_endpoint_address:?}"
        );

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        // The mock hands out exactly one token; the second request below must be
        // answered from the cache.
        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await?;
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag from the first response should yield
        // `NotModified` rather than a fresh set of headers.
        extensions.insert(entity_tag);

        let cached_headers = mdsc.headers(extensions).await?;
        match cached_headers {
            CacheableResource::New { .. } => {
                unreachable!("expecting NotModified for a matching entity tag")
            }
            CacheableResource::NotModified => {}
        }
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn access_token_success() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "Bearer".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(response)),
        );

        let creds = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_access_token_credentials()?;

        let access_token = creds.access_token().await?;
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        let result = mdsc.headers(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
    }

    #[test]
    #[parallel]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .created_by_adc(true)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let endpoint = if overridden {
            Some("http://127.0.0.1")
        } else {
            None
        };
        let provider = MDSAccessTokenProvider::builder()
            .endpoint(endpoint)
            .created_by_adc(adc)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        // If a metadata server is reachable (e.g. when running on Google Cloud),
        // the fetch succeeds and there is no error message to inspect.
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        // If a metadata server is reachable, the fetch succeeds and there is
        // nothing to check.
        let Err(e) = Builder::default().build_token_provider().token().await else {
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await?)?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        // `times(1)`: the second `headers` call must be served from the cache.
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.unwrap_err();
        assert!(!e.is_transient());

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_mds_signer() -> TestResult {
        use base64::{Engine, prelude::BASE64_STANDARD};
        use serde_json::json;

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "POST",
                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
            }))),
        );

        // The same mock server plays both the metadata server and the IAM endpoint.
        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let signer = Builder::default()
            .with_endpoint(&endpoint)
            .build_signer_with_iam_endpoint_override(Some(endpoint))?;

        let client_email = signer.client_email().await?;
        assert_eq!(client_email, "test-client-email");

        let signature = signer.sign(b"test").await?;
        assert_eq!(signature.as_ref(), b"signed_blob");

        Ok(())
    }
}