1use crate::access_boundary::CredentialsWithAccessBoundary;
76use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
77use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
78use crate::headers_util::AuthHeadersBuilder;
79use crate::mds::client::Client as MDSClient;
80use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
81use crate::token::{CachedTokenProvider, Token, TokenProvider};
82use crate::token_cache::TokenCache;
83use crate::{BuildResult, Result};
84use async_trait::async_trait;
85use google_cloud_gax::backoff_policy::BackoffPolicyArg;
86use google_cloud_gax::error::CredentialsError;
87use google_cloud_gax::retry_policy::RetryPolicyArg;
88use google_cloud_gax::retry_throttler::RetryThrottlerArg;
89use http::{Extensions, HeaderMap};
90use std::default::Default;
91use std::sync::{Arc, OnceLock};
92
/// Error message reported when fetching a token from the metadata service
/// fails and the provider was created by ADC against the default endpoint.
///
/// `MDSAccessTokenProvider::error_message()` selects this text; in every
/// other configuration a shorter generic message is used instead.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
101
/// Credentials backed by the metadata service (MDS).
///
/// Authentication headers are produced from the tokens returned by
/// `token_provider`; the universe domain is looked up through `mds_client`
/// unless an override was configured.
#[derive(Debug)]
struct MDSCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Optional quota project passed to the auth headers builder.
    quota_project_id: Option<String>,
    /// When set, `universe_domain()` returns this value without querying the
    /// metadata service.
    universe_domain_override: Option<String>,
    /// Result of the universe domain lookup, stored once it is known. Holds
    /// `Some(domain)` after a successful lookup and `None` after a
    /// non-transient failure; it stays unset after transient failures.
    universe_domain: OnceLock<Option<String>>,
    /// Source of the tokens used to build the authentication headers.
    token_provider: T,
    /// Client used to query the metadata service for the universe domain.
    mds_client: MDSClient,
}
113
/// Builder for credentials, access token credentials, and signers backed by
/// the metadata service (MDS).
#[derive(Debug)]
pub struct Builder {
    /// Endpoint handed to the MDS client; `None` leaves the choice to the
    /// client.
    endpoint: Option<String>,
    /// Quota project forwarded to the credentials, if any.
    quota_project_id: Option<String>,
    /// Universe domain override forwarded to the credentials, if any.
    universe_domain: Option<String>,
    /// Scopes requested with each access token, if any.
    scopes: Option<Vec<String>>,
    /// True when the builder was created through `Builder::from_adc()`. It is
    /// forwarded to the token provider, where it influences the error message.
    created_by_adc: bool,
    /// Collects the retry, backoff, and throttler policies and wraps the token
    /// provider when it is built.
    retry_builder: RetryTokenProviderBuilder,
    /// Optional IAM endpoint override, forwarded to the access boundary
    /// wrapper and to the signer. Only settable from tests in this file.
    iam_endpoint_override: Option<String>,
    /// When false, the credentials are wrapped in a no-op access boundary.
    /// Defaults to true and can only be disabled from tests in this file.
    is_access_boundary_enabled: bool,
}
136
137impl Default for Builder {
138 fn default() -> Self {
139 Self {
140 endpoint: None,
141 quota_project_id: None,
142 universe_domain: None,
143 scopes: None,
144 created_by_adc: false,
145 retry_builder: RetryTokenProviderBuilder::default(),
146 iam_endpoint_override: None,
147 is_access_boundary_enabled: true,
148 }
149 }
150}
151
152impl Builder {
153 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
168 self.endpoint = Some(endpoint.into());
169 self
170 }
171
172 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
181 self.quota_project_id = Some(quota_project_id.into());
182 self
183 }
184
185 #[allow(dead_code)]
189 pub(crate) fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
190 self.universe_domain = Some(universe_domain.into());
191 self
192 }
193
194 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
203 where
204 I: IntoIterator<Item = S>,
205 S: Into<String>,
206 {
207 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
208 self
209 }
210
211 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
226 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
227 self
228 }
229
230 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
246 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
247 self
248 }
249
250 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
271 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
272 self
273 }
274
275 #[cfg(test)]
276 fn maybe_iam_endpoint_override(mut self, iam_endpoint_override: Option<String>) -> Self {
277 self.iam_endpoint_override = iam_endpoint_override;
278 self
279 }
280
281 #[cfg(test)]
282 fn without_access_boundary(mut self) -> Self {
283 self.is_access_boundary_enabled = false;
284 self
285 }
286
287 pub(crate) fn from_adc() -> Self {
289 Self {
290 created_by_adc: true,
291 ..Default::default()
292 }
293 }
294
295 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
296 let tp = MDSAccessTokenProvider::builder()
297 .endpoint(self.endpoint)
298 .maybe_scopes(self.scopes)
299 .created_by_adc(self.created_by_adc)
300 .build();
301 self.retry_builder.build(tp)
302 }
303
304 pub fn build(self) -> BuildResult<Credentials> {
306 Ok(self.build_credentials()?.into())
307 }
308
309 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
325 Ok(self.build_credentials()?.into())
326 }
327
328 fn build_credentials(
329 self,
330 ) -> BuildResult<CredentialsWithAccessBoundary<MDSCredentials<TokenCache>>> {
331 let iam_endpoint = self.iam_endpoint_override.clone();
332 let is_access_boundary_enabled = self.is_access_boundary_enabled;
333 let mds_client = MDSClient::new(self.endpoint.clone());
334 let mdsc = MDSCredentials {
335 quota_project_id: self.quota_project_id.clone(),
336 universe_domain_override: self.universe_domain.clone(),
337 universe_domain: OnceLock::new(),
338 token_provider: TokenCache::new(self.build_token_provider()),
339 mds_client: mds_client.clone(),
340 };
341 if !is_access_boundary_enabled {
342 return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc));
343 }
344 Ok(CredentialsWithAccessBoundary::new_for_mds(
345 mdsc,
346 mds_client,
347 iam_endpoint,
348 ))
349 }
350
351 pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
368 let client = MDSClient::new(self.endpoint.clone());
369 let iam_endpoint = self.iam_endpoint_override.clone();
370 let credentials = self.build()?;
371 let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
372 let signing_provider = iam_endpoint
373 .iter()
374 .fold(signing_provider, |signing_provider, endpoint| {
375 signing_provider.with_iam_endpoint_override(endpoint)
376 });
377 Ok(crate::signer::Signer {
378 inner: Arc::new(signing_provider),
379 })
380 }
381}
382
383#[async_trait::async_trait]
384impl<T> CredentialsProvider for MDSCredentials<T>
385where
386 T: CachedTokenProvider,
387{
388 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
389 let token = self.token_provider.token(extensions).await?;
390
391 AuthHeadersBuilder::new(&token)
392 .maybe_quota_project_id(self.quota_project_id.as_deref())
393 .build()
394 }
395
396 async fn universe_domain(&self) -> Option<String> {
397 if let Some(ud) = &self.universe_domain_override {
398 return Some(ud.clone());
399 }
400 if let Some(ud) = self.universe_domain.get() {
401 return ud.clone();
402 }
403
404 let response = self.mds_client.universe_domain().send().await;
406 match response {
407 Ok(universe_domain) => {
408 let _ = self.universe_domain.set(Some(universe_domain.clone()));
409 Some(universe_domain)
410 }
411 Err(e) => {
412 if !e.is_transient() {
413 let _ = self.universe_domain.set(None);
415 }
416 None
419 }
420 }
421 }
422}
423
424#[async_trait::async_trait]
425impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
426where
427 T: CachedTokenProvider,
428{
429 async fn access_token(&self) -> Result<AccessToken> {
430 let token = self.token_provider.token(Extensions::new()).await?;
431 token.into()
432 }
433}
434
/// Collects the settings needed to create a `MDSAccessTokenProvider`.
#[derive(Debug, Default)]
struct MDSAccessTokenProviderBuilder {
    /// Scopes requested with each access token, if any.
    scopes: Option<Vec<String>>,
    /// Endpoint handed to the MDS client; `None` leaves the choice to the
    /// client.
    endpoint: Option<String>,
    /// Whether the provider is being created on behalf of ADC.
    created_by_adc: bool,
}
441
442impl MDSAccessTokenProviderBuilder {
443 fn build(self) -> MDSAccessTokenProvider {
444 MDSAccessTokenProvider {
445 client: MDSClient::new(self.endpoint),
446 scopes: self.scopes,
447 created_by_adc: self.created_by_adc,
448 }
449 }
450
451 fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
452 self.scopes = v;
453 self
454 }
455
456 fn endpoint<T>(mut self, v: Option<T>) -> Self
457 where
458 T: Into<String>,
459 {
460 self.endpoint = v.map(Into::into);
461 self
462 }
463
464 fn created_by_adc(mut self, v: bool) -> Self {
465 self.created_by_adc = v;
466 self
467 }
468}
469
/// Token provider that fetches access tokens from the metadata service.
#[derive(Debug, Clone)]
struct MDSAccessTokenProvider {
    /// Scopes sent with each access token request, if any.
    scopes: Option<Vec<String>>,
    /// Client used to send the access token requests.
    client: MDSClient,
    /// Whether the provider was created on behalf of ADC; together with the
    /// client using its default endpoint, this selects the longer error
    /// message.
    created_by_adc: bool,
}
476
477impl MDSAccessTokenProvider {
478 fn builder() -> MDSAccessTokenProviderBuilder {
479 MDSAccessTokenProviderBuilder::default()
480 }
481
482 fn error_message(&self) -> &str {
490 if self.use_adc_message() {
491 MDS_NOT_FOUND_ERROR
492 } else {
493 "failed to fetch token"
494 }
495 }
496
497 fn use_adc_message(&self) -> bool {
498 self.created_by_adc && self.client.is_default_endpoint
499 }
500}
501
502#[async_trait]
503impl TokenProvider for MDSAccessTokenProvider {
504 async fn token(&self) -> Result<Token> {
505 self.client
506 .access_token(self.scopes.clone())
507 .send()
508 .await
509 .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::mds::client::MDSTokenResponse;
    use crate::mds::{
        GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, MDS_UNIVERSE_DOMAIN_URI, METADATA_ROOT,
    };
    use crate::token::tests::MockTokenProvider;
    use crate::token_cache::TokenCache;
    use base64::{Engine, prelude::BASE64_STANDARD};
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serde_json::json;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use std::time::Duration;
    use test_case::test_case;
    use tokio::time::Instant;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    /// Returns the token response most tests serve from the fake metadata
    /// service.
    fn test_token_response() -> MDSTokenResponse {
        MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        }
    }

    /// Returns a mock token provider where every call fails with a
    /// non-retryable error.
    fn failing_token_provider() -> MockTokenProvider {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .returning(|| Err(errors::non_retryable_from_str("fail")));
        mock
    }

    /// Creates credentials without quota project or universe domain override,
    /// using `mock` for tokens and an MDS client for `endpoint`.
    fn mds_credentials(
        mock: MockTokenProvider,
        endpoint: Option<String>,
    ) -> MDSCredentials<TokenCache> {
        MDSCredentials {
            quota_project_id: None,
            universe_domain_override: None,
            universe_domain: OnceLock::new(),
            token_provider: TokenCache::new(mock),
            mds_client: MDSClient::new(endpoint),
        }
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient(), "{err:?}");
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = test_token_response();

        // Two transient failures followed by a successful response.
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    #[parallel]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(
            default_endpoint_address.is_ok(),
            "{default_endpoint_address:?}"
        );

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = mds_credentials(mock, None);

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // With the entity tag of the previous response the headers must be
        // reported as unchanged.
        extensions.insert(entity_tag);

        let cached_headers = mdsc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting cached headers, got {cached_headers:?}"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn access_token_success() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "Bearer".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(response)),
        );

        let creds = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .without_access_boundary()
            .build_access_token_credentials()
            .unwrap();

        let access_token = creds.access_token().await.unwrap();
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = mds_credentials(mock, None);
        let result = mdsc.headers(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
    }

    #[test]
    #[parallel]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .created_by_adc(true)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let endpoint = if overridden {
            Some("http://127.0.0.1")
        } else {
            None
        };
        let provider = MDSAccessTokenProvider::builder()
            .endpoint(endpoint)
            .created_by_adc(adc)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            // A token was obtained, so a metadata service is reachable from
            // this environment and there is no error to inspect.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let Err(e) = Builder::default().build_token_provider().token().await else {
            // A token was obtained, so a metadata service is reachable from
            // this environment and there is no error to inspect.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = test_token_response();
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .without_access_boundary()
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = test_token_response();
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .without_access_boundary()
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = test_token_response();
        // `times(1)`: the second `headers()` call must not reach the server.
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .without_access_boundary()
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = test_token_response();
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = test_token_response();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient(), "{e:?}");

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
                .respond_with(status_code(404)),
        );

        let creds = mds_credentials(
            failing_token_provider(),
            Some(format!("http://{}", server.addr())),
        );

        let universe_domain = creds.universe_domain().await;
        assert!(universe_domain.is_none(), "{universe_domain:?}");
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_universe_domain_override() -> TestResult {
        let creds = Builder::default()
            .with_universe_domain("my-universe-domain.com")
            .without_access_boundary()
            .build()?;
        let universe_domain = creds.universe_domain().await;
        assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_universe_domain_from_mds() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
                .respond_with(status_code(200).body("my-universe-domain.com")),
        );

        let creds = mds_credentials(
            failing_token_provider(),
            Some(format!("http://{}", server.addr())),
        );
        let universe_domain = creds.universe_domain().await;
        assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_universe_domain_caching() -> TestResult {
        let server = Server::run();
        // Exactly two requests: the transient failure is not cached, the
        // successful lookup is.
        server.expect(
            Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
                .times(2)
                .respond_with(cycle![
                    status_code(503).body("transient error"),
                    status_code(200).body("my-universe-domain.com"),
                ]),
        );

        let creds = mds_credentials(
            failing_token_provider(),
            Some(format!("http://{}", server.addr())),
        );

        let universe_domain = creds.universe_domain().await;
        assert_eq!(universe_domain, None);

        let universe_domain = creds.universe_domain().await;
        assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));

        let universe_domain = creds.universe_domain().await;
        assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_universe_domain_caching_permanent_error() -> TestResult {
        let server = Server::run();
        // Exactly one request: the permanent failure is cached.
        server.expect(
            Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
                .times(1)
                .respond_with(status_code(404).body("permanent error")),
        );

        let creds = mds_credentials(
            failing_token_provider(),
            Some(format!("http://{}", server.addr())),
        );

        let universe_domain = creds.universe_domain().await;
        assert_eq!(universe_domain, None);

        let universe_domain = creds.universe_domain().await;
        assert_eq!(universe_domain, None);

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_mds_signer() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "POST",
                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
            }))),
        );

        // The same fake server plays both the metadata service and IAM.
        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let signer = Builder::default()
            .with_endpoint(&endpoint)
            .maybe_iam_endpoint_override(Some(endpoint))
            .without_access_boundary()
            .build_signer()?;

        let client_email = signer.client_email().await?;
        assert_eq!(client_email, "test-client-email");

        let signature = signer.sign(b"test").await?;
        assert_eq!(signature.as_ref(), b"signed_blob");

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    #[cfg(google_cloud_unstable_trusted_boundaries)]
    async fn e2e_access_boundary() -> TestResult {
        use crate::credentials::tests::get_access_boundary_from_headers;

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
                .respond_with(status_code(404)),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "GET",
                    "/v1/projects/-/serviceAccounts/test-client-email/allowedLocations"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "locations": ["us-central1", "us-east1"],
                "encodedLocations": "0x1234"
            }))),
        );

        // The same fake server plays both the metadata service and IAM.
        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let creds = Builder::default()
            .with_endpoint(&endpoint)
            .maybe_iam_endpoint_override(Some(endpoint))
            .build_credentials()?;

        // NOTE(review): presumably waits for the access boundary lookup to
        // finish so the header below is populated; confirm against
        // `CredentialsWithAccessBoundary`.
        creds.wait_for_boundary().await;

        let headers = creds.headers(Extensions::new()).await?;
        let token = get_token_from_headers(headers.clone());
        let access_boundary = get_access_boundary_from_headers(headers);
        assert!(token.is_some(), "should have some token: {token:?}");
        assert_eq!(
            access_boundary.as_deref(),
            Some("0x1234"),
            "should be 0x1234 but found: {access_boundary:?}"
        );

        Ok(())
    }
}