1use crate::access_boundary::CredentialsWithAccessBoundary;
76use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
77use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
78use crate::headers_util::AuthHeadersBuilder;
79use crate::mds::client::Client as MDSClient;
80use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
81use crate::token::{CachedTokenProvider, Token, TokenProvider};
82use crate::token_cache::TokenCache;
83use crate::{BuildResult, Result};
84use async_trait::async_trait;
85use google_cloud_gax::backoff_policy::{BackoffPolicy, BackoffPolicyArg};
86use google_cloud_gax::error::CredentialsError;
87use google_cloud_gax::retry_policy::{RetryPolicy, RetryPolicyArg};
88use google_cloud_gax::retry_throttler::{RetryThrottlerArg, SharedRetryThrottler};
89use http::{Extensions, HeaderMap};
90use std::default::Default;
91use std::sync::{Arc, OnceLock};
92
/// Error text used when a token cannot be fetched from the metadata service.
///
/// `MDSAccessTokenProvider::error_message` returns this text only when the
/// provider was created by ADC and the client uses the default endpoint; in
/// every other case a shorter generic message is used instead.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
101
/// Credentials backed by the metadata service (MDS).
///
/// `T` supplies (cached) tokens; the remaining fields are used to build the
/// authentication headers and to resolve the universe domain.
#[derive(Debug)]
struct MDSCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Optional quota project forwarded to the auth headers builder.
    quota_project_id: Option<String>,
    /// When set, `universe_domain()` returns this value without querying MDS.
    universe_domain_override: Option<String>,
    /// Cached outcome of the MDS universe domain lookup. It is set to
    /// `Some(domain)` after a successful lookup and to `None` after a
    /// non-transient failure; transient failures leave it unset.
    universe_domain: OnceLock<Option<String>>,
    /// Source of the tokens used for `headers()` and `access_token()`.
    token_provider: T,
    /// Client used to query the universe domain from MDS.
    mds_client: MDSClient,
    /// Retry policy applied to the universe domain request.
    retry_policy: Arc<dyn RetryPolicy>,
    /// Backoff policy applied to the universe domain request.
    backoff_policy: Arc<dyn BackoffPolicy>,
    /// Retry throttler applied to the universe domain request.
    retry_throttler: SharedRetryThrottler,
}
116
/// Builder for metadata service (MDS) backed credentials.
///
/// Every option is optional; see `Default` for the initial values.
#[derive(Debug)]
pub struct Builder {
    /// Metadata service endpoint handed to the MDS client, if overridden.
    endpoint: Option<String>,
    /// Quota project forwarded to the auth headers, if any.
    quota_project_id: Option<String>,
    /// Universe domain override; when set it is returned without querying MDS.
    universe_domain: Option<String>,
    /// Scopes requested for the access token, if any.
    scopes: Option<Vec<String>>,
    /// True when the builder was created through `from_adc()`; this
    /// influences the error message reported on token fetch failures.
    created_by_adc: bool,
    /// Accumulates the retry, backoff and throttler configuration.
    retry_builder: RetryTokenProviderBuilder,
    /// Optional IAM endpoint override passed to the access boundary wrapper
    /// and to the signer.
    iam_endpoint_override: Option<String>,
    /// When false, the credentials are wrapped in a no-op access boundary.
    is_access_boundary_enabled: bool,
}
139
140impl Default for Builder {
141 fn default() -> Self {
142 Self {
143 endpoint: None,
144 quota_project_id: None,
145 universe_domain: None,
146 scopes: None,
147 created_by_adc: false,
148 retry_builder: RetryTokenProviderBuilder::default(),
149 iam_endpoint_override: None,
150 is_access_boundary_enabled: true,
151 }
152 }
153}
154
155impl Builder {
156 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
171 self.endpoint = Some(endpoint.into());
172 self
173 }
174
175 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
184 self.quota_project_id = Some(quota_project_id.into());
185 self
186 }
187
188 pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
204 self.universe_domain = Some(universe_domain.into());
205 self
206 }
207
208 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
217 where
218 I: IntoIterator<Item = S>,
219 S: Into<String>,
220 {
221 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
222 self
223 }
224
225 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
240 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
241 self
242 }
243
244 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
260 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
261 self
262 }
263
264 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
285 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
286 self
287 }
288
289 pub(crate) fn from_adc() -> Self {
291 Self {
292 created_by_adc: true,
293 ..Default::default()
294 }
295 }
296
297 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
298 let tp = MDSAccessTokenProvider::builder()
299 .endpoint(self.endpoint)
300 .maybe_scopes(self.scopes)
301 .created_by_adc(self.created_by_adc)
302 .build();
303 self.retry_builder.build(tp)
304 }
305
306 pub fn build(self) -> BuildResult<Credentials> {
308 Ok(self.build_credentials()?.into())
309 }
310
311 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
327 Ok(self.build_credentials()?.into())
328 }
329
330 fn build_credentials(
331 self,
332 ) -> BuildResult<CredentialsWithAccessBoundary<MDSCredentials<TokenCache>>> {
333 let iam_endpoint = self.iam_endpoint_override.clone();
334 let is_access_boundary_enabled = self.is_access_boundary_enabled;
335 let mds_client = MDSClient::new(self.endpoint.clone());
336 let retry_builder = self.retry_builder.clone();
337 let (backoff_policy, retry_throttler, retry_policy) = retry_builder.resolve();
338 let mdsc = MDSCredentials {
339 quota_project_id: self.quota_project_id.clone(),
340 universe_domain_override: self.universe_domain.clone(),
341 universe_domain: OnceLock::new(),
342 token_provider: TokenCache::new(self.build_token_provider()),
343 mds_client: mds_client.clone(),
344 backoff_policy,
345 retry_throttler,
346 retry_policy,
347 };
348 if !is_access_boundary_enabled {
349 return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc));
350 }
351 Ok(CredentialsWithAccessBoundary::new_for_mds(
352 mdsc,
353 mds_client,
354 iam_endpoint,
355 ))
356 }
357
358 pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
375 let client = MDSClient::new(self.endpoint.clone());
376 let iam_endpoint = self.iam_endpoint_override.clone();
377 let credentials = self.build()?;
378 let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
379 let signing_provider = iam_endpoint
380 .iter()
381 .fold(signing_provider, |signing_provider, endpoint| {
382 signing_provider.with_iam_endpoint_override(endpoint)
383 });
384 Ok(crate::signer::Signer {
385 inner: Arc::new(signing_provider),
386 })
387 }
388}
389
390#[async_trait::async_trait]
391impl<T> CredentialsProvider for MDSCredentials<T>
392where
393 T: CachedTokenProvider,
394{
395 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
396 let token = self.token_provider.token(extensions).await?;
397
398 AuthHeadersBuilder::new(&token)
399 .maybe_quota_project_id(self.quota_project_id.as_deref())
400 .build()
401 }
402
403 async fn universe_domain(&self) -> Option<String> {
404 if let Some(ud) = &self.universe_domain_override {
405 return Some(ud.clone());
406 }
407 if let Some(ud) = self.universe_domain.get() {
408 return ud.clone();
409 }
410
411 let response = self
413 .mds_client
414 .universe_domain()
415 .with_backoff_policy(self.backoff_policy.clone().into())
416 .with_retry_policy(self.retry_policy.clone().into())
417 .with_retry_throttler(self.retry_throttler.clone().into())
418 .send()
419 .await;
420 match response {
421 Ok(universe_domain) => {
422 let _ = self.universe_domain.set(Some(universe_domain.clone()));
423 Some(universe_domain)
424 }
425 Err(e) => {
426 if !e.is_transient() {
427 let _ = self.universe_domain.set(None);
429 }
430 None
433 }
434 }
435 }
436}
437
438#[async_trait::async_trait]
439impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
440where
441 T: CachedTokenProvider,
442{
443 async fn access_token(&self) -> Result<AccessToken> {
444 let token = self.token_provider.token(Extensions::new()).await?;
445 token.into()
446 }
447}
448
/// Internal builder for [`MDSAccessTokenProvider`].
#[derive(Debug, Default)]
struct MDSAccessTokenProviderBuilder {
    /// Scopes requested with the access token, if any.
    scopes: Option<Vec<String>>,
    /// Endpoint passed to the MDS client, if overridden.
    endpoint: Option<String>,
    /// Whether the provider is being created on behalf of ADC.
    created_by_adc: bool,
}
455
456impl MDSAccessTokenProviderBuilder {
457 fn build(self) -> MDSAccessTokenProvider {
458 MDSAccessTokenProvider {
459 client: MDSClient::new(self.endpoint),
460 scopes: self.scopes,
461 created_by_adc: self.created_by_adc,
462 }
463 }
464
465 fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
466 self.scopes = v;
467 self
468 }
469
470 fn endpoint<T>(mut self, v: Option<T>) -> Self
471 where
472 T: Into<String>,
473 {
474 self.endpoint = v.map(Into::into);
475 self
476 }
477
478 fn created_by_adc(mut self, v: bool) -> Self {
479 self.created_by_adc = v;
480 self
481 }
482}
483
/// Token provider that fetches access tokens from the metadata service.
#[derive(Debug, Clone)]
struct MDSAccessTokenProvider {
    /// Scopes sent with each access token request, if any.
    scopes: Option<Vec<String>>,
    /// Client used to call the metadata service.
    client: MDSClient,
    /// Whether the provider was created on behalf of ADC; together with the
    /// client's default-endpoint flag it selects the error message.
    created_by_adc: bool,
}
490
491impl MDSAccessTokenProvider {
492 fn builder() -> MDSAccessTokenProviderBuilder {
493 MDSAccessTokenProviderBuilder::default()
494 }
495
496 fn error_message(&self) -> &str {
504 if self.use_adc_message() {
505 MDS_NOT_FOUND_ERROR
506 } else {
507 "failed to fetch token"
508 }
509 }
510
511 fn use_adc_message(&self) -> bool {
512 self.created_by_adc && self.client.is_default_endpoint
513 }
514}
515
516#[async_trait]
517impl TokenProvider for MDSAccessTokenProvider {
518 async fn token(&self) -> Result<Token> {
519 self.client
520 .access_token(self.scopes.clone())
521 .send()
522 .await
523 .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
524 }
525}
526
527#[cfg(test)]
528mod tests {
529 use super::*;
530 use crate::credentials::QUOTA_PROJECT_KEY;
531 use crate::credentials::tests::{
532 find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
533 get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
534 get_token_type_from_headers,
535 };
536 use crate::errors;
537 use crate::errors::CredentialsError;
538 use crate::mds::client::MDSTokenResponse;
539 use crate::mds::{
540 GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, MDS_UNIVERSE_DOMAIN_URI, METADATA_ROOT,
541 };
542 use crate::token::tests::MockTokenProvider;
543 use crate::token_cache::TokenCache;
544 use base64::{Engine, prelude::BASE64_STANDARD};
545 use http::HeaderValue;
546 use http::header::AUTHORIZATION;
547 use httptest::cycle;
548 use httptest::matchers::{all_of, contains, request, url_decoded};
549 use httptest::responders::{json_encoded, status_code};
550 use httptest::{Expectation, Server};
551 use reqwest::StatusCode;
552 use scoped_env::ScopedEnv;
553 use serde_json::json;
554 use serial_test::{parallel, serial};
555 use std::error::Error;
556 use std::time::Duration;
557 use test_case::test_case;
558 use tokio::time::Instant;
559 use url::Url;
560
561 type TestResult = anyhow::Result<()>;
562
563 impl Builder {
564 fn maybe_iam_endpoint_override(mut self, iam_endpoint_override: Option<String>) -> Self {
565 self.iam_endpoint_override = iam_endpoint_override;
566 self
567 }
568
569 fn without_access_boundary(mut self) -> Self {
570 self.is_access_boundary_enabled = false;
571 self
572 }
573 }
574
575 #[tokio::test]
576 #[parallel]
577 async fn test_mds_retries_on_transient_failures() -> TestResult {
578 let mut server = Server::run();
579 server.expect(
580 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
581 .times(3)
582 .respond_with(status_code(503)),
583 );
584
585 let provider = Builder::default()
586 .with_endpoint(format!("http://{}", server.addr()))
587 .with_retry_policy(get_mock_auth_retry_policy(3))
588 .with_backoff_policy(get_mock_backoff_policy())
589 .with_retry_throttler(get_mock_retry_throttler())
590 .build_token_provider();
591
592 let err = provider.token().await.unwrap_err();
593 assert!(err.is_transient(), "{err:?}");
594 server.verify_and_clear();
595 Ok(())
596 }
597
598 #[tokio::test]
599 #[parallel]
600 async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
601 let mut server = Server::run();
602 server.expect(
603 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
604 .times(1)
605 .respond_with(status_code(401)),
606 );
607
608 let provider = Builder::default()
609 .with_endpoint(format!("http://{}", server.addr()))
610 .with_retry_policy(get_mock_auth_retry_policy(1))
611 .with_backoff_policy(get_mock_backoff_policy())
612 .with_retry_throttler(get_mock_retry_throttler())
613 .build_token_provider();
614
615 let err = provider.token().await.unwrap_err();
616 assert!(!err.is_transient());
617 server.verify_and_clear();
618 Ok(())
619 }
620
621 #[tokio::test]
622 #[parallel]
623 async fn test_mds_retries_for_success() -> TestResult {
624 let mut server = Server::run();
625 let response = MDSTokenResponse {
626 access_token: "test-access-token".to_string(),
627 expires_in: Some(3600),
628 token_type: "test-token-type".to_string(),
629 };
630
631 server.expect(
632 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
633 .times(3)
634 .respond_with(cycle![
635 status_code(503).body("try-again"),
636 status_code(503).body("try-again"),
637 status_code(200)
638 .append_header("Content-Type", "application/json")
639 .body(serde_json::to_string(&response).unwrap()),
640 ]),
641 );
642
643 let provider = Builder::default()
644 .with_endpoint(format!("http://{}", server.addr()))
645 .with_retry_policy(get_mock_auth_retry_policy(3))
646 .with_backoff_policy(get_mock_backoff_policy())
647 .with_retry_throttler(get_mock_retry_throttler())
648 .build_token_provider();
649
650 let token = provider.token().await?;
651 assert_eq!(token.token, "test-access-token");
652
653 server.verify_and_clear();
654 Ok(())
655 }
656
// The default endpoint and the token endpoint derived from it must be valid URLs.
#[test]
#[parallel]
fn validate_default_endpoint_urls() {
    let default_endpoint = format!("{METADATA_ROOT}{MDS_DEFAULT_URI}");
    let default_endpoint_address = Url::parse(&default_endpoint);
    assert!(
        default_endpoint_address.is_ok(),
        "{default_endpoint_address:?}"
    );

    let token_endpoint = format!("{default_endpoint}/token");
    let token_endpoint_address = Url::parse(&token_endpoint);
    assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
}
669
670 #[tokio::test]
671 #[parallel]
672 async fn headers_success() -> TestResult {
673 let token = Token {
674 token: "test-token".to_string(),
675 token_type: "Bearer".to_string(),
676 expires_at: None,
677 metadata: None,
678 };
679
680 let mut mock = MockTokenProvider::new();
681 mock.expect_token().times(1).return_once(|| Ok(token));
682
683 let mdsc = MDSCredentials {
684 quota_project_id: None,
685 token_provider: TokenCache::new(mock),
686 universe_domain_override: None,
687 universe_domain: OnceLock::new(),
688 mds_client: MDSClient::new(None),
689 backoff_policy: Arc::new(get_mock_backoff_policy()),
690 retry_throttler: Arc::new(std::sync::Mutex::new(get_mock_retry_throttler())),
691 retry_policy: Arc::new(get_mock_auth_retry_policy(1)),
692 };
693
694 let mut extensions = Extensions::new();
695 let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
696 let (headers, entity_tag) = match cached_headers {
697 CacheableResource::New { entity_tag, data } => (data, entity_tag),
698 CacheableResource::NotModified => unreachable!("expecting new headers"),
699 };
700 let token = headers.get(AUTHORIZATION).unwrap();
701 assert_eq!(headers.len(), 1, "{headers:?}");
702 assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
703 assert!(token.is_sensitive());
704
705 extensions.insert(entity_tag);
706
707 let cached_headers = mdsc.headers(extensions).await?;
708
709 match cached_headers {
710 CacheableResource::New { .. } => unreachable!("expecting new headers"),
711 CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
712 };
713 Ok(())
714 }
715
716 #[tokio::test]
717 #[parallel]
718 async fn access_token_success() -> TestResult {
719 let server = Server::run();
720 let response = MDSTokenResponse {
721 access_token: "test-access-token".to_string(),
722 expires_in: Some(3600),
723 token_type: "Bearer".to_string(),
724 };
725 server.expect(
726 Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
727 .respond_with(json_encoded(response)),
728 );
729
730 let creds = Builder::default()
731 .with_endpoint(format!("http://{}", server.addr()))
732 .without_access_boundary()
733 .build_access_token_credentials()
734 .unwrap();
735
736 let access_token = creds.access_token().await.unwrap();
737 assert_eq!(access_token.token, "test-access-token");
738
739 Ok(())
740 }
741
742 #[tokio::test]
743 #[parallel]
744 async fn headers_failure() {
745 let mut mock = MockTokenProvider::new();
746 mock.expect_token()
747 .times(1)
748 .return_once(|| Err(errors::non_retryable_from_str("fail")));
749
750 let mdsc = MDSCredentials {
751 quota_project_id: None,
752 token_provider: TokenCache::new(mock),
753 universe_domain_override: None,
754 universe_domain: OnceLock::new(),
755 mds_client: MDSClient::new(None),
756 backoff_policy: Arc::new(get_mock_backoff_policy()),
757 retry_throttler: Arc::new(std::sync::Mutex::new(get_mock_retry_throttler())),
758 retry_policy: Arc::new(get_mock_auth_retry_policy(1)),
759 };
760 let result = mdsc.headers(Extensions::new()).await;
761 assert!(result.is_err(), "{result:?}");
762 }
763
// Created by ADC with the default endpoint: the long ADC guidance is expected.
#[test]
#[parallel]
fn error_message_with_adc() {
    let builder = MDSAccessTokenProvider::builder().created_by_adc(true);
    let provider = builder.build();

    let got = provider.error_message();
    let want = MDS_NOT_FOUND_ERROR;
    assert!(got.contains(want), "{got}, {provider:?}");
}
775
776 #[test_case(false, false)]
777 #[test_case(false, true)]
778 #[test_case(true, true)]
779 fn error_message_without_adc(adc: bool, overridden: bool) {
780 let endpoint = if overridden {
781 Some("http://127.0.0.1")
782 } else {
783 None
784 };
785 let provider = MDSAccessTokenProvider::builder()
786 .endpoint(endpoint)
787 .created_by_adc(adc)
788 .build();
789
790 let not_want = MDS_NOT_FOUND_ERROR;
791 let got = provider.error_message();
792 assert!(!got.contains(not_want), "{got}, {provider:?}");
793 }
794
795 #[tokio::test]
796 #[serial]
797 async fn adc_no_mds() -> TestResult {
798 let Err(err) = Builder::from_adc().build_token_provider().token().await else {
799 return Ok(());
801 };
802
803 let original_err = err.source().unwrap();
804 assert!(
805 original_err.to_string().contains("application-default"),
806 "display={err}, debug={err:?}"
807 );
808
809 Ok(())
810 }
811
// With the metadata host overridden through the environment, an ADC-created
// provider must fail with a transient error that does not carry the
// application-default guidance. Runs serially because it mutates the
// process environment.
#[tokio::test]
#[serial]
async fn adc_overridden_mds() -> TestResult {
    // Override the metadata host before the provider is built.
    let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

    let err = Builder::from_adc()
        .build_token_provider()
        .token()
        .await
        .unwrap_err();

    // Remove the variable once the request has completed.
    let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

    let original_err = find_source_error::<CredentialsError>(&err).unwrap();
    assert!(original_err.is_transient());
    assert!(
        !original_err.to_string().contains("application-default"),
        "display={err}, debug={err:?}"
    );
    // The underlying reqwest error must not carry an HTTP status.
    let source = find_source_error::<reqwest::Error>(&err);
    assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

    Ok(())
}
836
837 #[tokio::test]
838 #[serial]
839 async fn builder_no_mds() -> TestResult {
840 let Err(e) = Builder::default().build_token_provider().token().await else {
841 return Ok(());
843 };
844
845 let original_err = find_source_error::<CredentialsError>(&e).unwrap();
846 assert!(
847 !format!("{:?}", original_err.source()).contains("application-default"),
848 "{e:?}"
849 );
850
851 Ok(())
852 }
853
// The metadata host environment variable must redirect token requests to the
// fake server when no endpoint is configured on the builder. Runs serially
// because it mutates the process environment.
#[tokio::test]
#[serial]
async fn test_gce_metadata_host_env_var() -> TestResult {
    let server = Server::run();
    let scopes = ["scope1", "scope2"];
    let response = MDSTokenResponse {
        access_token: "test-access-token".to_string(),
        expires_in: Some(3600),
        token_type: "test-token-type".to_string(),
    };
    // The request must hit the token path with the comma-joined scopes.
    server.expect(
        Expectation::matching(all_of![
            request::path(format!("{MDS_DEFAULT_URI}/token")),
            request::query(url_decoded(contains(("scopes", scopes.join(",")))))
        ])
        .respond_with(json_encoded(response)),
    );

    // Set the variable before building the credentials and fetching headers.
    let addr = server.addr().to_string();
    let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
    let mdsc = Builder::default()
        .with_scopes(["scope1", "scope2"])
        .without_access_boundary()
        .build()
        .unwrap();
    let headers = mdsc.headers(Extensions::new()).await.unwrap();
    // Remove the variable once the request has completed.
    let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

    assert_eq!(
        get_token_from_headers(headers).unwrap(),
        "test-access-token"
    );
    Ok(())
}
888
889 #[tokio::test]
890 #[parallel]
891 async fn headers_success_with_quota_project() -> TestResult {
892 let server = Server::run();
893 let scopes = ["scope1", "scope2"];
894 let response = MDSTokenResponse {
895 access_token: "test-access-token".to_string(),
896 expires_in: Some(3600),
897 token_type: "test-token-type".to_string(),
898 };
899 server.expect(
900 Expectation::matching(all_of![
901 request::path(format!("{MDS_DEFAULT_URI}/token")),
902 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
903 ])
904 .respond_with(json_encoded(response)),
905 );
906
907 let mdsc = Builder::default()
908 .with_scopes(["scope1", "scope2"])
909 .with_endpoint(format!("http://{}", server.addr()))
910 .with_quota_project_id("test-project")
911 .without_access_boundary()
912 .build()?;
913
914 let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
915 let token = headers.get(AUTHORIZATION).unwrap();
916 let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
917
918 assert_eq!(headers.len(), 2, "{headers:?}");
919 assert_eq!(
920 token,
921 HeaderValue::from_static("test-token-type test-access-token")
922 );
923 assert!(token.is_sensitive());
924 assert_eq!(quota_project, HeaderValue::from_static("test-project"));
925 assert!(!quota_project.is_sensitive());
926
927 Ok(())
928 }
929
930 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
931 #[parallel]
932 async fn token_caching() -> TestResult {
933 let server = Server::run();
934 let scopes = vec!["scope1".to_string()];
935 let response = MDSTokenResponse {
936 access_token: "test-access-token".to_string(),
937 expires_in: Some(3600),
938 token_type: "test-token-type".to_string(),
939 };
940 server.expect(
941 Expectation::matching(all_of![
942 request::path(format!("{MDS_DEFAULT_URI}/token")),
943 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
944 ])
945 .times(1)
946 .respond_with(json_encoded(response)),
947 );
948
949 let mdsc = Builder::default()
950 .with_scopes(scopes)
951 .with_endpoint(format!("http://{}", server.addr()))
952 .without_access_boundary()
953 .build()?;
954 let headers = mdsc.headers(Extensions::new()).await?;
955 assert_eq!(
956 get_token_from_headers(headers).unwrap(),
957 "test-access-token"
958 );
959 let headers = mdsc.headers(Extensions::new()).await?;
960 assert_eq!(
961 get_token_from_headers(headers).unwrap(),
962 "test-access-token"
963 );
964
965 Ok(())
966 }
967
968 #[tokio::test(start_paused = true)]
969 #[parallel]
970 async fn token_provider_full() -> TestResult {
971 let server = Server::run();
972 let scopes = vec!["scope1".to_string()];
973 let response = MDSTokenResponse {
974 access_token: "test-access-token".to_string(),
975 expires_in: Some(3600),
976 token_type: "test-token-type".to_string(),
977 };
978 server.expect(
979 Expectation::matching(all_of![
980 request::path(format!("{MDS_DEFAULT_URI}/token")),
981 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
982 ])
983 .respond_with(json_encoded(response)),
984 );
985
986 let token = Builder::default()
987 .with_endpoint(format!("http://{}", server.addr()))
988 .with_scopes(scopes)
989 .build_token_provider()
990 .token()
991 .await?;
992
993 let now = tokio::time::Instant::now();
994 assert_eq!(token.token, "test-access-token");
995 assert_eq!(token.token_type, "test-token-type");
996 assert!(
997 token
998 .expires_at
999 .is_some_and(|d| d >= now + Duration::from_secs(3600))
1000 );
1001
1002 Ok(())
1003 }
1004
1005 #[tokio::test(start_paused = true)]
1006 #[parallel]
1007 async fn token_provider_full_no_scopes() -> TestResult {
1008 let server = Server::run();
1009 let response = MDSTokenResponse {
1010 access_token: "test-access-token".to_string(),
1011 expires_in: Some(3600),
1012 token_type: "test-token-type".to_string(),
1013 };
1014 server.expect(
1015 Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
1016 .respond_with(json_encoded(response)),
1017 );
1018
1019 let token = Builder::default()
1020 .with_endpoint(format!("http://{}", server.addr()))
1021 .build_token_provider()
1022 .token()
1023 .await?;
1024
1025 let now = Instant::now();
1026 assert_eq!(token.token, "test-access-token");
1027 assert_eq!(token.token_type, "test-token-type");
1028 assert!(
1029 token
1030 .expires_at
1031 .is_some_and(|d| d == now + Duration::from_secs(3600))
1032 );
1033
1034 Ok(())
1035 }
1036
1037 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1038 #[parallel]
1039 async fn credential_provider_full() -> TestResult {
1040 let server = Server::run();
1041 let scopes = vec!["scope1".to_string()];
1042 let response = MDSTokenResponse {
1043 access_token: "test-access-token".to_string(),
1044 expires_in: None,
1045 token_type: "test-token-type".to_string(),
1046 };
1047 server.expect(
1048 Expectation::matching(all_of![
1049 request::path(format!("{MDS_DEFAULT_URI}/token")),
1050 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1051 ])
1052 .respond_with(json_encoded(response)),
1053 );
1054
1055 let mdsc = Builder::default()
1056 .with_endpoint(format!("http://{}", server.addr()))
1057 .with_scopes(scopes)
1058 .without_access_boundary()
1059 .build()?;
1060 let headers = mdsc.headers(Extensions::new()).await?;
1061 assert_eq!(
1062 get_token_from_headers(headers.clone()).unwrap(),
1063 "test-access-token"
1064 );
1065 assert_eq!(
1066 get_token_type_from_headers(headers).unwrap(),
1067 "test-token-type"
1068 );
1069
1070 Ok(())
1071 }
1072
1073 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1074 #[parallel]
1075 async fn credentials_headers_retryable_error() -> TestResult {
1076 let server = Server::run();
1077 let scopes = vec!["scope1".to_string()];
1078 server.expect(
1079 Expectation::matching(all_of![
1080 request::path(format!("{MDS_DEFAULT_URI}/token")),
1081 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1082 ])
1083 .respond_with(status_code(503)),
1084 );
1085
1086 let mdsc = Builder::default()
1087 .with_endpoint(format!("http://{}", server.addr()))
1088 .with_scopes(scopes)
1089 .without_access_boundary()
1090 .build()?;
1091 let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1092 let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1093 assert!(original_err.is_transient());
1094 let source = find_source_error::<google_cloud_gax::error::Error>(&err);
1095 assert!(
1096 matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::SERVICE_UNAVAILABLE.into())),
1097 "{err:?}"
1098 );
1099
1100 Ok(())
1101 }
1102
1103 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1104 #[parallel]
1105 async fn credentials_headers_nonretryable_error() -> TestResult {
1106 let server = Server::run();
1107 let scopes = vec!["scope1".to_string()];
1108 server.expect(
1109 Expectation::matching(all_of![
1110 request::path(format!("{MDS_DEFAULT_URI}/token")),
1111 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1112 ])
1113 .respond_with(status_code(401)),
1114 );
1115
1116 let mdsc = Builder::default()
1117 .with_endpoint(format!("http://{}", server.addr()))
1118 .with_scopes(scopes)
1119 .without_access_boundary()
1120 .build()?;
1121
1122 let err = mdsc.headers(Extensions::new()).await.unwrap_err();
1123 let original_err = find_source_error::<CredentialsError>(&err).unwrap();
1124 assert!(!original_err.is_transient());
1125 let source = find_source_error::<google_cloud_gax::error::Error>(&err);
1126 assert!(
1127 matches!(source, Some(e) if e.http_status_code() == Some(StatusCode::UNAUTHORIZED.into())),
1128 "{err:?}"
1129 );
1130
1131 Ok(())
1132 }
1133
1134 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
1135 #[parallel]
1136 async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
1137 let server = Server::run();
1138 let scopes = vec!["scope1".to_string()];
1139 server.expect(
1140 Expectation::matching(all_of![
1141 request::path(format!("{MDS_DEFAULT_URI}/token")),
1142 request::query(url_decoded(contains(("scopes", scopes.join(",")))))
1143 ])
1144 .respond_with(json_encoded("bad json")),
1145 );
1146
1147 let mdsc = Builder::default()
1148 .with_endpoint(format!("http://{}", server.addr()))
1149 .with_scopes(scopes)
1150 .without_access_boundary()
1151 .build()?;
1152
1153 let e = mdsc.headers(Extensions::new()).await.err().unwrap();
1154 assert!(!e.is_transient());
1155
1156 Ok(())
1157 }
1158
1159 #[tokio::test]
1160 #[parallel]
1161 async fn get_default_universe_domain() -> TestResult {
1162 let server = Server::run();
1163 server.expect(
1164 Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
1165 .respond_with(status_code(404)),
1166 );
1167
1168 let mut mock = MockTokenProvider::new();
1169 mock.expect_token()
1170 .returning(|| Err(crate::errors::non_retryable_from_str("fail")));
1171
1172 let creds = MDSCredentials {
1173 quota_project_id: None,
1174 universe_domain_override: None,
1175 universe_domain: OnceLock::new(),
1176 token_provider: TokenCache::new(mock),
1177 mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
1178 backoff_policy: Arc::new(get_mock_backoff_policy()),
1179 retry_throttler: Arc::new(std::sync::Mutex::new(get_mock_retry_throttler())),
1180 retry_policy: Arc::new(get_mock_auth_retry_policy(1)),
1181 };
1182
1183 let universe_domain = creds.universe_domain().await;
1184 assert!(universe_domain.is_none());
1185 Ok(())
1186 }
1187
1188 #[tokio::test]
1189 #[parallel]
1190 async fn get_universe_domain_override() -> TestResult {
1191 let creds = Builder::default()
1192 .with_universe_domain("my-universe-domain.com")
1193 .without_access_boundary()
1194 .build()?;
1195 let universe_domain = creds.universe_domain().await;
1196 assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
1197 Ok(())
1198 }
1199
1200 #[tokio::test]
1201 #[parallel]
1202 async fn get_universe_domain_from_mds() -> TestResult {
1203 let server = Server::run();
1204 server.expect(
1205 Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
1206 .respond_with(status_code(200).body("my-universe-domain.com")),
1207 );
1208
1209 let mut mock = MockTokenProvider::new();
1210 mock.expect_token()
1211 .returning(|| Err(crate::errors::non_retryable_from_str("fail")));
1212
1213 let creds = MDSCredentials {
1214 quota_project_id: None,
1215 universe_domain_override: None,
1216 universe_domain: OnceLock::new(),
1217 token_provider: TokenCache::new(mock),
1218 mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
1219 backoff_policy: Arc::new(get_mock_backoff_policy()),
1220 retry_throttler: Arc::new(std::sync::Mutex::new(get_mock_retry_throttler())),
1221 retry_policy: Arc::new(get_mock_auth_retry_policy(1)),
1222 };
1223 let universe_domain = creds.universe_domain().await;
1224 assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
1225 Ok(())
1226 }
1227
1228 #[tokio::test]
1229 #[parallel]
1230 async fn get_universe_domain_retries_on_transient_failures() -> TestResult {
1231 let server = Server::run();
1232 server.expect(
1233 Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
1234 .times(3)
1235 .respond_with(cycle![
1236 status_code(503).body("transient error"),
1237 status_code(503).body("transient error"),
1238 status_code(200).body("my-universe-domain.com"),
1239 ]),
1240 );
1241
1242 let mut mock = MockTokenProvider::new();
1243 mock.expect_token()
1244 .returning(|| Err(crate::errors::non_retryable_from_str("fail")));
1245
1246 let creds = MDSCredentials {
1247 quota_project_id: None,
1248 universe_domain_override: None,
1249 universe_domain: OnceLock::new(),
1250 token_provider: TokenCache::new(mock),
1251 mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
1252 backoff_policy: Arc::new(get_mock_backoff_policy()),
1253 retry_throttler: Arc::new(std::sync::Mutex::new(get_mock_retry_throttler())),
1254 retry_policy: Arc::new(get_mock_auth_retry_policy(3)),
1255 };
1256
1257 let universe_domain = creds.universe_domain().await;
1258 assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
1259
1260 Ok(())
1261 }
1262
1263 #[tokio::test]
1264 #[parallel]
1265 async fn get_universe_domain_caching() -> TestResult {
1266 let server = Server::run();
1267 server.expect(
1268 Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
1269 .times(2)
1270 .respond_with(cycle![
1271 status_code(503).body("transient error"),
1272 status_code(200).body("my-universe-domain.com"),
1273 ]),
1274 );
1275
1276 let mut mock = MockTokenProvider::new();
1277 mock.expect_token()
1278 .returning(|| Err(crate::errors::non_retryable_from_str("fail")));
1279
1280 let creds = MDSCredentials {
1281 quota_project_id: None,
1282 universe_domain_override: None,
1283 universe_domain: OnceLock::new(),
1284 token_provider: TokenCache::new(mock),
1285 mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
1286 backoff_policy: Arc::new(get_mock_backoff_policy()),
1287 retry_throttler: Arc::new(std::sync::Mutex::new(get_mock_retry_throttler())),
1288 retry_policy: Arc::new(get_mock_auth_retry_policy(1)),
1289 };
1290
1291 let universe_domain = creds.universe_domain().await;
1292 assert_eq!(universe_domain, None);
1293
1294 let universe_domain = creds.universe_domain().await;
1295 assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
1296
1297 let universe_domain = creds.universe_domain().await;
1298 assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
1299
1300 Ok(())
1301 }
1302
1303 #[tokio::test]
1304 #[parallel]
1305 async fn get_universe_domain_caching_permanent_error() -> TestResult {
1306 let server = Server::run();
1307 server.expect(
1308 Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
1309 .times(1)
1310 .respond_with(status_code(404).body("permanent error")),
1311 );
1312
1313 let mut mock = MockTokenProvider::new();
1314 mock.expect_token()
1315 .returning(|| Err(crate::errors::non_retryable_from_str("fail")));
1316
1317 let creds = MDSCredentials {
1318 quota_project_id: None,
1319 universe_domain_override: None,
1320 universe_domain: OnceLock::new(),
1321 token_provider: TokenCache::new(mock),
1322 mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
1323 backoff_policy: Arc::new(get_mock_backoff_policy()),
1324 retry_throttler: Arc::new(std::sync::Mutex::new(get_mock_retry_throttler())),
1325 retry_policy: Arc::new(get_mock_auth_retry_policy(1)),
1326 };
1327
1328 let universe_domain = creds.universe_domain().await;
1329 assert_eq!(universe_domain, None);
1330
1331 let universe_domain = creds.universe_domain().await;
1332 assert_eq!(universe_domain, None);
1333
1334 Ok(())
1335 }
1336
1337 #[tokio::test]
1338 #[parallel]
1339 async fn get_mds_signer() -> TestResult {
1340 let server = Server::run();
1341 server.expect(
1342 Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
1343 .respond_with(json_encoded(MDSTokenResponse {
1344 access_token: "test-access-token".to_string(),
1345 expires_in: None,
1346 token_type: "Bearer".to_string(),
1347 })),
1348 );
1349 server.expect(
1350 Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
1351 .respond_with(status_code(200).body("test-client-email")),
1352 );
1353 server.expect(
1354 Expectation::matching(all_of![
1355 request::method_path(
1356 "POST",
1357 "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
1358 ),
1359 request::headers(contains(("authorization", "Bearer test-access-token"))),
1360 ])
1361 .respond_with(json_encoded(json!({
1362 "signedBlob": BASE64_STANDARD.encode("signed_blob"),
1363 }))),
1364 );
1365
1366 let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
1367
1368 let signer = Builder::default()
1369 .with_endpoint(&endpoint)
1370 .maybe_iam_endpoint_override(Some(endpoint))
1371 .without_access_boundary()
1372 .build_signer()?;
1373
1374 let client_email = signer.client_email().await?;
1375 assert_eq!(client_email, "test-client-email");
1376
1377 let signature = signer.sign(b"test").await?;
1378 assert_eq!(signature.as_ref(), b"signed_blob");
1379
1380 Ok(())
1381 }
1382
// Builds credentials through `Builder` against a fake server and checks that
// the resulting headers carry a token and the access boundary value served by
// the `allowedLocations` endpoint.
#[tokio::test]
#[parallel]
#[cfg(google_cloud_unstable_trusted_boundaries)]
async fn e2e_access_boundary() -> TestResult {
    use crate::credentials::tests::get_access_boundary_from_headers;
    use crate::mds::MDS_UNIVERSE_DOMAIN_URI;

    let fake = Server::run();

    // Token path: a bearer token without an expiration.
    let token_path = format!("{MDS_DEFAULT_URI}/token");
    let token_body = MDSTokenResponse {
        access_token: "test-access-token".to_string(),
        expires_in: None,
        token_type: "Bearer".to_string(),
    };
    fake.expect(
        Expectation::matching(all_of![request::path(token_path),])
            .respond_with(json_encoded(token_body)),
    );

    // Email path: the service account email as a plain body.
    let email_path = format!("{MDS_DEFAULT_URI}/email");
    fake.expect(
        Expectation::matching(all_of![request::path(email_path),])
            .respond_with(status_code(200).body("test-client-email")),
    );

    // Universe domain path: answers 404.
    fake.expect(
        Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
            .respond_with(status_code(404)),
    );

    // allowedLocations: must be a GET carrying the token above.
    let allowed_locations = json!({
        "locations": ["us-central1", "us-east1"],
        "encodedLocations": "0x1234"
    });
    fake.expect(
        Expectation::matching(all_of![
            request::method_path(
                "GET",
                "/v1/projects/-/serviceAccounts/test-client-email/allowedLocations"
            ),
            request::headers(contains(("authorization", "Bearer test-access-token"))),
        ])
        .respond_with(json_encoded(allowed_locations)),
    );

    // The same base URL (without its trailing slash) is used for both the
    // metadata endpoint and the IAM endpoint override.
    let server_url = fake.url("").to_string();
    let endpoint = server_url.trim_end_matches('/').to_string();

    let builder = Builder::default().with_endpoint(&endpoint);
    let credentials = builder
        .maybe_iam_endpoint_override(Some(endpoint))
        .build_credentials()?;

    // Wait before reading headers so the boundary lookup is no longer pending.
    credentials.wait_for_boundary().await;

    let auth_headers = credentials.headers(Extensions::new()).await?;
    let token = get_token_from_headers(auth_headers.clone());
    let access_boundary = get_access_boundary_from_headers(auth_headers);
    assert!(token.is_some(), "should have some token: {token:?}");
    assert_eq!(
        access_boundary.as_deref(),
        Some("0x1234"),
        "should be 0x1234 but found: {access_boundary:?}"
    );

    Ok(())
}
1443}