1use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
78use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use gax::backoff_policy::BackoffPolicyArg;
87use gax::retry_policy::RetryPolicyArg;
88use gax::retry_throttler::RetryThrottlerArg;
89use http::{Extensions, HeaderMap, HeaderValue};
90use reqwest::Client;
91use std::default::Default;
92use std::sync::Arc;
93use std::time::Duration;
94use tokio::time::Instant;
95
/// Root URL of the metadata server, used when no endpoint override is configured.
pub(crate) const METADATA_ROOT: &str = "http://metadata.google.internal";

/// Path (relative to the metadata root) of the default service account.
/// Token requests append `/token` to this path.
pub(crate) const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";

/// Environment variable holding a metadata server host (optionally with a port).
/// `http://` is prepended to its value, and it is consulted before any
/// endpoint configured on the builder.
pub(crate) const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";

/// Name of the header sent with every token request to the metadata server.
pub(crate) const METADATA_FLAVOR: &str = "metadata-flavor";

/// Value sent in the [`METADATA_FLAVOR`] header.
pub(crate) const METADATA_FLAVOR_VALUE: &str = "Google";

/// Error message used when a provider created through ADC fails to fetch a
/// token from the default (non-overridden) metadata endpoint. It points the
/// user at setting up local credentials instead.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries";
109
110#[derive(Debug)]
111struct MDSCredentials<T>
112where
113 T: CachedTokenProvider,
114{
115 quota_project_id: Option<String>,
116 token_provider: T,
117}
118
119#[derive(Debug, Default)]
131pub struct Builder {
132 endpoint: Option<String>,
133 quota_project_id: Option<String>,
134 scopes: Option<Vec<String>>,
135 created_by_adc: bool,
136 retry_builder: RetryTokenProviderBuilder,
137}
138
139impl Builder {
140 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
155 self.endpoint = Some(endpoint.into());
156 self
157 }
158
159 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
168 self.quota_project_id = Some(quota_project_id.into());
169 self
170 }
171
172 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
181 where
182 I: IntoIterator<Item = S>,
183 S: Into<String>,
184 {
185 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
186 self
187 }
188
189 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
204 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
205 self
206 }
207
208 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
224 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
225 self
226 }
227
228 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
249 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
250 self
251 }
252
253 pub(crate) fn from_adc() -> Self {
255 Self {
256 created_by_adc: true,
257 ..Default::default()
258 }
259 }
260
261 fn resolve_endpoint(&self) -> (String, bool) {
262 let final_endpoint: String;
263 let endpoint_overridden: bool;
264 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
266 final_endpoint = format!("http://{host_from_env}");
268 endpoint_overridden = true;
269 } else if let Some(builder_endpoint) = self.endpoint.clone() {
270 final_endpoint = builder_endpoint;
272 endpoint_overridden = true;
273 } else {
274 final_endpoint = METADATA_ROOT.to_string();
276 endpoint_overridden = false;
277 };
278 (final_endpoint, endpoint_overridden)
279 }
280
281 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
282 let (final_endpoint, endpoint_overridden) = self.resolve_endpoint();
283
284 let tp = MDSAccessTokenProvider::builder()
285 .endpoint(final_endpoint)
286 .maybe_scopes(self.scopes)
287 .endpoint_overridden(endpoint_overridden)
288 .created_by_adc(self.created_by_adc)
289 .build();
290 self.retry_builder.build(tp)
291 }
292
293 pub fn build(self) -> BuildResult<Credentials> {
295 Ok(self.build_access_token_credentials()?.into())
296 }
297
298 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
315 let mdsc = MDSCredentials {
316 quota_project_id: self.quota_project_id.clone(),
317 token_provider: TokenCache::new(self.build_token_provider()),
318 };
319 Ok(AccessTokenCredentials {
320 inner: Arc::new(mdsc),
321 })
322 }
323
324 pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
342 self.build_signer_with_iam_endpoint_override(None)
343 }
344
345 fn build_signer_with_iam_endpoint_override(
347 self,
348 iam_endpoint: Option<String>,
349 ) -> BuildResult<crate::signer::Signer> {
350 let (endpoint, _) = self.resolve_endpoint();
351 let credentials = self.build()?;
352 let signing_provider = crate::signer::mds::MDSSigner::new(endpoint, credentials);
353 let signing_provider = iam_endpoint
354 .iter()
355 .fold(signing_provider, |signing_provider, endpoint| {
356 signing_provider.with_iam_endpoint_override(endpoint)
357 });
358 Ok(crate::signer::Signer {
359 inner: Arc::new(signing_provider),
360 })
361 }
362}
363
364#[async_trait::async_trait]
365impl<T> CredentialsProvider for MDSCredentials<T>
366where
367 T: CachedTokenProvider,
368{
369 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
370 let cached_token = self.token_provider.token(extensions).await?;
371 build_cacheable_headers(&cached_token, &self.quota_project_id)
372 }
373}
374
375#[async_trait::async_trait]
376impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
377where
378 T: CachedTokenProvider,
379{
380 async fn access_token(&self) -> Result<AccessToken> {
381 let token = self.token_provider.token(Extensions::new()).await?;
382 token.into()
383 }
384}
385
386#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
387struct MDSTokenResponse {
388 access_token: String,
389 #[serde(skip_serializing_if = "Option::is_none")]
390 expires_in: Option<u64>,
391 token_type: String,
392}
393
394#[derive(Debug, Default)]
395struct MDSAccessTokenProviderBuilder {
396 target: MDSAccessTokenProvider,
397}
398
399impl MDSAccessTokenProviderBuilder {
400 fn build(self) -> MDSAccessTokenProvider {
401 self.target
402 }
403
404 fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
405 self.target.scopes = v;
406 self
407 }
408
409 fn endpoint<T>(mut self, v: T) -> Self
410 where
411 T: Into<String>,
412 {
413 self.target.endpoint = v.into();
414 self
415 }
416
417 fn endpoint_overridden(mut self, v: bool) -> Self {
418 self.target.endpoint_overridden = v;
419 self
420 }
421
422 fn created_by_adc(mut self, v: bool) -> Self {
423 self.target.created_by_adc = v;
424 self
425 }
426}
427
/// Fetches access tokens from the metadata server token endpoint.
#[derive(Debug, Clone, Default)]
struct MDSAccessTokenProvider {
    /// OAuth scopes to request; `None` omits the `scopes` query parameter.
    scopes: Option<Vec<String>>,
    /// Base URL of the metadata server, e.g. `http://metadata.google.internal`.
    endpoint: String,
    /// Whether `endpoint` came from an override rather than the default.
    endpoint_overridden: bool,
    /// Whether the provider was created through ADC.
    created_by_adc: bool,
}
435
436impl MDSAccessTokenProvider {
437 fn builder() -> MDSAccessTokenProviderBuilder {
438 MDSAccessTokenProviderBuilder::default()
439 }
440
441 fn error_message(&self) -> &str {
449 if self.use_adc_message() {
450 MDS_NOT_FOUND_ERROR
451 } else {
452 "failed to fetch token"
453 }
454 }
455
456 fn use_adc_message(&self) -> bool {
457 self.created_by_adc && !self.endpoint_overridden
458 }
459}
460
461#[async_trait]
462impl TokenProvider for MDSAccessTokenProvider {
463 async fn token(&self) -> Result<Token> {
464 let client = Client::new();
465 let request = client
466 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
467 .header(
468 METADATA_FLAVOR,
469 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
470 );
471 let scopes = self.scopes.as_ref().map(|v| v.join(","));
474 let request = scopes
475 .into_iter()
476 .fold(request, |r, s| r.query(&[("scopes", s)]));
477
478 let response = request
483 .send()
484 .await
485 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
486 if !response.status().is_success() {
488 let err = crate::errors::from_http_response(response, self.error_message()).await;
489 return Err(err);
490 }
491 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
492 CredentialsError::from_source(!e.is_decode(), e)
496 })?;
497 let token = Token {
498 token: response.access_token,
499 token_type: response.token_type,
500 expires_at: response
501 .expires_in
502 .map(|d| Instant::now() + Duration::from_secs(d)),
503 metadata: None,
504 };
505 Ok(token)
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use test_case::test_case;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    // NOTE: `resolve_endpoint` consults `GCE_METADATA_HOST` before the
    // endpoint passed to `with_endpoint`, and some `#[serial]` tests below set
    // that variable. Every test relying on `with_endpoint` must therefore be
    // marked `#[parallel]` so it never overlaps with those tests.

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok());
    }

    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag from the previous result must yield NotModified.
        extensions.insert(entity_tag);

        let cached_headers = mdsc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting NotModified for a matching entity tag"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn access_token_success() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "Bearer".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(response)),
        );

        let creds = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_access_token_credentials()
            .unwrap();

        let access_token = creds.access_token().await.unwrap();
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    #[test]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(true)
            .endpoint_overridden(false)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(adc)
            .endpoint_overridden(overridden)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let Err(e) = Builder::default().build_token_provider().token().await else {
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    #[tokio::test]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_mds_signer() -> TestResult {
        use base64::{Engine, prelude::BASE64_STANDARD};
        use serde_json::json;

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "POST",
                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
            }))),
        );

        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let signer = Builder::default()
            .with_endpoint(&endpoint)
            .build_signer_with_iam_endpoint_override(Some(endpoint))?;

        let client_email = signer.client_email().await?;
        assert_eq!(client_email, "test-client-email");

        let signature = signer.sign(b"test").await?;
        assert_eq!(signature.as_ref(), b"signed_blob");

        Ok(())
    }
}