1use crate::access_boundary::CredentialsWithAccessBoundary;
78use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
79use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
80use crate::headers_util::AuthHeadersBuilder;
81use crate::mds::client::Client as MDSClient;
82use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
83use crate::token::{CachedTokenProvider, Token, TokenProvider};
84use crate::token_cache::TokenCache;
85use crate::{BuildResult, Result};
86use async_trait::async_trait;
87use google_cloud_gax::backoff_policy::BackoffPolicyArg;
88use google_cloud_gax::error::CredentialsError;
89use google_cloud_gax::retry_policy::RetryPolicyArg;
90use google_cloud_gax::retry_throttler::RetryThrottlerArg;
91use http::{Extensions, HeaderMap};
92use std::default::Default;
93use std::sync::Arc;
94
/// Message attached to token-fetch errors when the credentials were created
/// through ADC and the MDS client uses its default endpoint.
///
/// See `MDSAccessTokenProvider::error_message` for the exact selection rule.
/// The literal uses line continuations: a trailing `\` drops the newline and
/// the leading whitespace of the next line, so the value is a single line.
const MDS_NOT_FOUND_ERROR: &str = "\
    Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries";
103
104#[derive(Debug)]
105struct MDSCredentials<T>
106where
107 T: CachedTokenProvider,
108{
109 quota_project_id: Option<String>,
110 token_provider: T,
111}
112
/// Creates credentials backed by the metadata service (MDS).
///
/// Configure the builder with its `with_*` methods, then finish with
/// [Builder::build], [Builder::build_access_token_credentials], or
/// [Builder::build_signer].
#[derive(Debug)]
pub struct Builder {
    /// Endpoint handed to every MDS client this builder creates; `None`
    /// leaves the choice to the client.
    endpoint: Option<String>,
    /// Quota project passed to the auth headers builder, if any.
    quota_project_id: Option<String>,
    /// Scopes forwarded to the MDS client with each token request.
    scopes: Option<Vec<String>>,
    /// Set by [Builder::from_adc]; only influences the error message used
    /// when a token request fails.
    created_by_adc: bool,
    /// Accumulates the retry policy, backoff policy, and retry throttler, and
    /// wraps the token provider when the builder is finished.
    retry_builder: RetryTokenProviderBuilder,
    /// Optional IAM endpoint passed to the access boundary wrapper and to the
    /// signer. Only settable from tests.
    iam_endpoint_override: Option<String>,
    /// When `false`, the credentials are wrapped with the no-op access
    /// boundary. Defaults to `true` and is only changed from tests.
    is_access_boundary_enabled: bool,
}
134
135impl Default for Builder {
136 fn default() -> Self {
137 Self {
138 endpoint: None,
139 quota_project_id: None,
140 scopes: None,
141 created_by_adc: false,
142 retry_builder: RetryTokenProviderBuilder::default(),
143 iam_endpoint_override: None,
144 is_access_boundary_enabled: true,
145 }
146 }
147}
148
149impl Builder {
150 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
165 self.endpoint = Some(endpoint.into());
166 self
167 }
168
169 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
178 self.quota_project_id = Some(quota_project_id.into());
179 self
180 }
181
182 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
191 where
192 I: IntoIterator<Item = S>,
193 S: Into<String>,
194 {
195 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
196 self
197 }
198
199 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
214 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
215 self
216 }
217
218 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
234 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
235 self
236 }
237
238 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
259 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
260 self
261 }
262
263 #[cfg(test)]
264 fn maybe_iam_endpoint_override(mut self, iam_endpoint_override: Option<String>) -> Self {
265 self.iam_endpoint_override = iam_endpoint_override;
266 self
267 }
268
269 #[cfg(test)]
270 fn without_access_boundary(mut self) -> Self {
271 self.is_access_boundary_enabled = false;
272 self
273 }
274
275 pub(crate) fn from_adc() -> Self {
277 Self {
278 created_by_adc: true,
279 ..Default::default()
280 }
281 }
282
283 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
284 let tp = MDSAccessTokenProvider::builder()
285 .endpoint(self.endpoint)
286 .maybe_scopes(self.scopes)
287 .created_by_adc(self.created_by_adc)
288 .build();
289 self.retry_builder.build(tp)
290 }
291
292 pub fn build(self) -> BuildResult<Credentials> {
294 Ok(self.build_credentials()?.into())
295 }
296
297 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
314 Ok(self.build_credentials()?.into())
315 }
316
317 fn build_credentials(
318 self,
319 ) -> BuildResult<CredentialsWithAccessBoundary<MDSCredentials<TokenCache>>> {
320 let iam_endpoint = self.iam_endpoint_override.clone();
321 let is_access_boundary_enabled = self.is_access_boundary_enabled;
322 let mds_client = MDSClient::new(self.endpoint.clone());
323 let mdsc = MDSCredentials {
324 quota_project_id: self.quota_project_id.clone(),
325 token_provider: TokenCache::new(self.build_token_provider()),
326 };
327 if !is_access_boundary_enabled {
328 return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc));
329 }
330 Ok(CredentialsWithAccessBoundary::new_for_mds(
331 mdsc,
332 mds_client,
333 iam_endpoint,
334 ))
335 }
336
337 pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
355 let client = MDSClient::new(self.endpoint.clone());
356 let iam_endpoint = self.iam_endpoint_override.clone();
357 let credentials = self.build()?;
358 let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
359 let signing_provider = iam_endpoint
360 .iter()
361 .fold(signing_provider, |signing_provider, endpoint| {
362 signing_provider.with_iam_endpoint_override(endpoint)
363 });
364 Ok(crate::signer::Signer {
365 inner: Arc::new(signing_provider),
366 })
367 }
368}
369
370#[async_trait::async_trait]
371impl<T> CredentialsProvider for MDSCredentials<T>
372where
373 T: CachedTokenProvider,
374{
375 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
376 let token = self.token_provider.token(extensions).await?;
377
378 AuthHeadersBuilder::new(&token)
379 .maybe_quota_project_id(self.quota_project_id.as_deref())
380 .build()
381 }
382}
383
384#[async_trait::async_trait]
385impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
386where
387 T: CachedTokenProvider,
388{
389 async fn access_token(&self) -> Result<AccessToken> {
390 let token = self.token_provider.token(Extensions::new()).await?;
391 token.into()
392 }
393}
394
/// Collects the settings needed to create a `MDSAccessTokenProvider`.
///
/// The derived default has no scopes, no endpoint, and `created_by_adc`
/// set to `false`.
#[derive(Debug, Default)]
struct MDSAccessTokenProviderBuilder {
    /// Scopes copied into the provider.
    scopes: Option<Vec<String>>,
    /// Endpoint handed to the MDS client when the provider is built.
    endpoint: Option<String>,
    /// Whether the provider is being created on behalf of ADC.
    created_by_adc: bool,
}
401
402impl MDSAccessTokenProviderBuilder {
403 fn build(self) -> MDSAccessTokenProvider {
404 MDSAccessTokenProvider {
405 client: MDSClient::new(self.endpoint),
406 scopes: self.scopes,
407 created_by_adc: self.created_by_adc,
408 }
409 }
410
411 fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
412 self.scopes = v;
413 self
414 }
415
416 fn endpoint<T>(mut self, v: Option<T>) -> Self
417 where
418 T: Into<String>,
419 {
420 self.endpoint = v.map(Into::into);
421 self
422 }
423
424 fn created_by_adc(mut self, v: bool) -> Self {
425 self.created_by_adc = v;
426 self
427 }
428}
429
/// Fetches access tokens from the metadata service.
#[derive(Debug, Clone)]
struct MDSAccessTokenProvider {
    /// Scopes passed to the MDS client on every token request.
    scopes: Option<Vec<String>>,
    /// Client used to perform the token request.
    client: MDSClient,
    /// Whether the provider was created on behalf of ADC; only used to pick
    /// the error message.
    created_by_adc: bool,
}
436
437impl MDSAccessTokenProvider {
438 fn builder() -> MDSAccessTokenProviderBuilder {
439 MDSAccessTokenProviderBuilder::default()
440 }
441
442 fn error_message(&self) -> &str {
450 if self.use_adc_message() {
451 MDS_NOT_FOUND_ERROR
452 } else {
453 "failed to fetch token"
454 }
455 }
456
457 fn use_adc_message(&self) -> bool {
458 self.created_by_adc && self.client.is_default_endpoint
459 }
460}
461
462#[async_trait]
463impl TokenProvider for MDSAccessTokenProvider {
464 async fn token(&self) -> Result<Token> {
465 self.client
466 .access_token(self.scopes.clone())
467 .await
468 .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
469 }
470}
471
#[cfg(test)]
mod tests {
    //! Unit tests for the MDS credentials.
    //!
    //! Most tests run against a local `httptest` server. Tests that change or
    //! depend on the `GCE_METADATA_HOST` environment variable are `#[serial]`;
    //! all others are `#[parallel]`.

    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::mds::client::MDSTokenResponse;
    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
    use crate::token::tests::MockTokenProvider;
    use base64::{Engine, prelude::BASE64_STANDARD};
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serde_json::json;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use std::time::Duration;
    use test_case::test_case;
    use tokio::time::Instant;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    /// A 503 on every attempt is retried until the policy gives up, and the
    /// final error is still transient.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");
        server.verify_and_clear();
        Ok(())
    }

    /// A 401 is reported after a single request and is not transient.
    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    /// Two 503 responses followed by a 200 yield the token from the third
    /// request.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    /// The default endpoint constants combine into parseable URLs.
    #[test]
    #[parallel]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(
            default_endpoint_address.is_ok(),
            "{default_endpoint_address:?}"
        );

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
    }

    /// The first call returns new headers; a second call carrying the
    /// returned entity tag reports that nothing changed.
    #[tokio::test]
    #[parallel]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let CacheableResource::New {
            entity_tag,
            data: headers,
        } = cached_headers
        else {
            panic!("expecting new headers");
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        let cached_headers = mdsc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting cached headers, got {cached_headers:?}"
        );
        Ok(())
    }

    /// The access token credentials return the token served by the MDS.
    #[tokio::test]
    #[parallel]
    async fn access_token_success() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "Bearer".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(response)),
        );

        let creds = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .without_access_boundary()
            .build_access_token_credentials()
            .unwrap();

        let access_token = creds.access_token().await.unwrap();
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    /// Token provider errors surface from `headers()`.
    #[tokio::test]
    #[parallel]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        let result = mdsc.headers(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
    }

    /// ADC-created providers on the default endpoint use the ADC guidance.
    #[test]
    #[parallel]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .created_by_adc(true)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    /// Without ADC, or with an overridden endpoint, the generic message is
    /// used.
    ///
    /// Marked `#[parallel]` like its siblings so it never overlaps with the
    /// `#[serial]` tests that change `GCE_METADATA_HOST`.
    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    #[parallel]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let endpoint = if overridden {
            Some("http://127.0.0.1")
        } else {
            None
        };
        let provider = MDSAccessTokenProvider::builder()
            .endpoint(endpoint)
            .created_by_adc(adc)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    /// When ADC-created credentials cannot fetch a token, the error mentions
    /// `application-default`.
    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            // The request succeeded, so a metadata service is reachable from
            // this environment and there is no error to inspect.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    /// With `GCE_METADATA_HOST` overridden, the ADC guidance is not used and
    /// the failure is a transient error without an HTTP status.
    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    /// Credentials not created by ADC never attach the ADC guidance.
    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let Err(e) = Builder::default().build_token_provider().token().await else {
            // The request succeeded, so a metadata service is reachable from
            // this environment and there is no error to inspect.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    /// With no explicit endpoint, the host in `GCE_METADATA_HOST` receives
    /// the token request.
    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .without_access_boundary()
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    /// The quota project is sent as a second, non-sensitive header.
    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .without_access_boundary()
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    /// Two `headers()` calls result in a single token request.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        // Checks the `times(1)` expectation: the second call was served
        // without another token request.
        server.verify_and_clear();

        Ok(())
    }

    /// The token provider returns the token, its type, and an expiration
    /// derived from `expires_in`.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    /// Same as `token_provider_full`, but without scopes.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    /// The full credentials expose both the token and its type in headers.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    /// A 503 surfaces as a transient error whose source carries the status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    /// A 401 surfaces as a non-transient error whose source carries the
    /// status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    /// A response body that is not a token object is a non-transient error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    /// Default-built credentials report the default universe domain.
    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    /// The signer reads the client email from the MDS and signs through the
    /// overridden IAM endpoint using the MDS token.
    #[tokio::test]
    #[parallel]
    async fn get_mds_signer() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "POST",
                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
            }))),
        );

        // The same server plays both the MDS and the IAM service.
        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let signer = Builder::default()
            .with_endpoint(&endpoint)
            .maybe_iam_endpoint_override(Some(endpoint))
            .without_access_boundary()
            .build_signer()?;

        let client_email = signer.client_email().await?;
        assert_eq!(client_email, "test-client-email");

        let signature = signer.sign(b"test").await?;
        assert_eq!(signature.as_ref(), b"signed_blob");

        Ok(())
    }

    /// With the access boundary enabled, the encoded locations returned by
    /// the IAM endpoint show up in the headers.
    #[tokio::test]
    #[parallel]
    #[cfg(google_cloud_unstable_trusted_boundaries)]
    async fn e2e_access_boundary() -> TestResult {
        use crate::credentials::tests::get_access_boundary_from_headers;

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "GET",
                    "/v1/projects/-/serviceAccounts/test-client-email/allowedLocations"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "locations": ["us-central1", "us-east1"],
                "encodedLocations": "0x1234"
            }))),
        );

        // The same server plays both the MDS and the IAM service.
        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let creds = Builder::default()
            .with_endpoint(&endpoint)
            .maybe_iam_endpoint_override(Some(endpoint))
            .build_credentials()?;

        // Wait for the boundary before asking for headers, so the assertion
        // below can expect it to be present.
        creds.wait_for_boundary().await;

        let headers = creds.headers(Extensions::new()).await?;
        let token = get_token_from_headers(headers.clone());
        let access_boundary = get_access_boundary_from_headers(headers);
        assert!(token.is_some(), "should have some token: {token:?}");
        assert_eq!(
            access_boundary.as_deref(),
            Some("0x1234"),
            "should be 0x1234 but found: {access_boundary:?}"
        );

        Ok(())
    }
}