1use crate::access_boundary::CredentialsWithAccessBoundary;
76use crate::credentials::dynamic::{AccessTokenCredentialsProvider, CredentialsProvider};
77use crate::credentials::{AccessToken, AccessTokenCredentials, CacheableResource, Credentials};
78use crate::headers_util::AuthHeadersBuilder;
79use crate::mds::client::Client as MDSClient;
80use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
81use crate::token::{CachedTokenProvider, Token, TokenProvider};
82use crate::token_cache::TokenCache;
83use crate::{BuildResult, Result};
84use async_trait::async_trait;
85use google_cloud_gax::backoff_policy::BackoffPolicyArg;
86use google_cloud_gax::error::CredentialsError;
87use google_cloud_gax::retry_policy::RetryPolicyArg;
88use google_cloud_gax::retry_throttler::RetryThrottlerArg;
89use http::{Extensions, HeaderMap};
90use std::default::Default;
91use std::sync::Arc;
92
/// Error message reported when fetching a token from the metadata server fails
/// for credentials that were created through Application Default Credentials
/// (ADC) and use the default metadata server endpoint.
///
/// Other failures use a shorter, generic message instead.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at \
    https://cloud.google.com/docs/authentication/client-libraries";
101
/// Credentials that authenticate requests with tokens from `token_provider`.
///
/// In this module `token_provider` is a `TokenCache` wrapping a retrying
/// metadata-server token provider (see `Builder::build_credentials`).
///
/// The `CachedTokenProvider` bound lives on the trait impls that need it rather
/// than on the struct definition, so the type itself places no requirements on `T`.
#[derive(Debug)]
struct MDSCredentials<T> {
    /// Optional quota project; when set it is added to the request headers.
    quota_project_id: Option<String>,
    /// Source of (possibly cached) access tokens.
    token_provider: T,
}
110
111#[derive(Debug)]
123pub struct Builder {
124 endpoint: Option<String>,
125 quota_project_id: Option<String>,
126 scopes: Option<Vec<String>>,
127 created_by_adc: bool,
128 retry_builder: RetryTokenProviderBuilder,
129 iam_endpoint_override: Option<String>,
130 is_access_boundary_enabled: bool,
131}
132
133impl Default for Builder {
134 fn default() -> Self {
135 Self {
136 endpoint: None,
137 quota_project_id: None,
138 scopes: None,
139 created_by_adc: false,
140 retry_builder: RetryTokenProviderBuilder::default(),
141 iam_endpoint_override: None,
142 is_access_boundary_enabled: true,
143 }
144 }
145}
146
147impl Builder {
148 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
163 self.endpoint = Some(endpoint.into());
164 self
165 }
166
167 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
176 self.quota_project_id = Some(quota_project_id.into());
177 self
178 }
179
180 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
189 where
190 I: IntoIterator<Item = S>,
191 S: Into<String>,
192 {
193 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
194 self
195 }
196
197 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
212 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
213 self
214 }
215
216 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
232 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
233 self
234 }
235
236 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
257 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
258 self
259 }
260
261 #[cfg(test)]
262 fn maybe_iam_endpoint_override(mut self, iam_endpoint_override: Option<String>) -> Self {
263 self.iam_endpoint_override = iam_endpoint_override;
264 self
265 }
266
267 #[cfg(test)]
268 fn without_access_boundary(mut self) -> Self {
269 self.is_access_boundary_enabled = false;
270 self
271 }
272
273 pub(crate) fn from_adc() -> Self {
275 Self {
276 created_by_adc: true,
277 ..Default::default()
278 }
279 }
280
281 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
282 let tp = MDSAccessTokenProvider::builder()
283 .endpoint(self.endpoint)
284 .maybe_scopes(self.scopes)
285 .created_by_adc(self.created_by_adc)
286 .build();
287 self.retry_builder.build(tp)
288 }
289
290 pub fn build(self) -> BuildResult<Credentials> {
292 Ok(self.build_credentials()?.into())
293 }
294
295 pub fn build_access_token_credentials(self) -> BuildResult<AccessTokenCredentials> {
311 Ok(self.build_credentials()?.into())
312 }
313
314 fn build_credentials(
315 self,
316 ) -> BuildResult<CredentialsWithAccessBoundary<MDSCredentials<TokenCache>>> {
317 let iam_endpoint = self.iam_endpoint_override.clone();
318 let is_access_boundary_enabled = self.is_access_boundary_enabled;
319 let mds_client = MDSClient::new(self.endpoint.clone());
320 let mdsc = MDSCredentials {
321 quota_project_id: self.quota_project_id.clone(),
322 token_provider: TokenCache::new(self.build_token_provider()),
323 };
324 if !is_access_boundary_enabled {
325 return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc));
326 }
327 Ok(CredentialsWithAccessBoundary::new_for_mds(
328 mdsc,
329 mds_client,
330 iam_endpoint,
331 ))
332 }
333
334 pub fn build_signer(self) -> BuildResult<crate::signer::Signer> {
351 let client = MDSClient::new(self.endpoint.clone());
352 let iam_endpoint = self.iam_endpoint_override.clone();
353 let credentials = self.build()?;
354 let signing_provider = crate::signer::mds::MDSSigner::new(client, credentials);
355 let signing_provider = iam_endpoint
356 .iter()
357 .fold(signing_provider, |signing_provider, endpoint| {
358 signing_provider.with_iam_endpoint_override(endpoint)
359 });
360 Ok(crate::signer::Signer {
361 inner: Arc::new(signing_provider),
362 })
363 }
364}
365
366#[async_trait::async_trait]
367impl<T> CredentialsProvider for MDSCredentials<T>
368where
369 T: CachedTokenProvider,
370{
371 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
372 let token = self.token_provider.token(extensions).await?;
373
374 AuthHeadersBuilder::new(&token)
375 .maybe_quota_project_id(self.quota_project_id.as_deref())
376 .build()
377 }
378}
379
380#[async_trait::async_trait]
381impl<T> AccessTokenCredentialsProvider for MDSCredentials<T>
382where
383 T: CachedTokenProvider,
384{
385 async fn access_token(&self) -> Result<AccessToken> {
386 let token = self.token_provider.token(Extensions::new()).await?;
387 token.into()
388 }
389}
390
391#[derive(Debug, Default)]
392struct MDSAccessTokenProviderBuilder {
393 scopes: Option<Vec<String>>,
394 endpoint: Option<String>,
395 created_by_adc: bool,
396}
397
398impl MDSAccessTokenProviderBuilder {
399 fn build(self) -> MDSAccessTokenProvider {
400 MDSAccessTokenProvider {
401 client: MDSClient::new(self.endpoint),
402 scopes: self.scopes,
403 created_by_adc: self.created_by_adc,
404 }
405 }
406
407 fn maybe_scopes(mut self, v: Option<Vec<String>>) -> Self {
408 self.scopes = v;
409 self
410 }
411
412 fn endpoint<T>(mut self, v: Option<T>) -> Self
413 where
414 T: Into<String>,
415 {
416 self.endpoint = v.map(Into::into);
417 self
418 }
419
420 fn created_by_adc(mut self, v: bool) -> Self {
421 self.created_by_adc = v;
422 self
423 }
424}
425
426#[derive(Debug, Clone)]
427struct MDSAccessTokenProvider {
428 scopes: Option<Vec<String>>,
429 client: MDSClient,
430 created_by_adc: bool,
431}
432
433impl MDSAccessTokenProvider {
434 fn builder() -> MDSAccessTokenProviderBuilder {
435 MDSAccessTokenProviderBuilder::default()
436 }
437
438 fn error_message(&self) -> &str {
446 if self.use_adc_message() {
447 MDS_NOT_FOUND_ERROR
448 } else {
449 "failed to fetch token"
450 }
451 }
452
453 fn use_adc_message(&self) -> bool {
454 self.created_by_adc && self.client.is_default_endpoint
455 }
456}
457
458#[async_trait]
459impl TokenProvider for MDSAccessTokenProvider {
460 async fn token(&self) -> Result<Token> {
461 self.client
462 .access_token(self.scopes.clone())
463 .await
464 .map_err(|e| CredentialsError::new(e.is_transient(), self.error_message(), e))
465 }
466}
467
/// Unit tests for the metadata-server credentials and token provider.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::mds::client::MDSTokenResponse;
    use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
    use crate::token::tests::MockTokenProvider;
    use base64::{Engine, prelude::BASE64_STANDARD};
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serde_json::json;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use std::time::Duration;
    use test_case::test_case;
    use tokio::time::Instant;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(err.is_transient(), "{err:?}");
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    #[parallel]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(
            default_endpoint_address.is_ok(),
            "{default_endpoint_address:?}"
        );

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok(), "{token_endpoint_address:?}");
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        // Presenting the entity tag from the first call must yield `NotModified`.
        let cached_headers = mdsc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting cached headers to be reported as not modified"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn access_token_success() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "Bearer".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(response)),
        );

        let creds = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .without_access_boundary()
            .build_access_token_credentials()
            .unwrap();

        let access_token = creds.access_token().await.unwrap();
        assert_eq!(access_token.token, "test-access-token");

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        let result = mdsc.headers(Extensions::new()).await;
        assert!(result.is_err(), "{result:?}");
    }

    #[test]
    #[parallel]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .created_by_adc(true)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let endpoint = if overridden {
            Some("http://127.0.0.1")
        } else {
            None
        };
        let provider = MDSAccessTokenProvider::builder()
            .endpoint(endpoint)
            .created_by_adc(adc)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            // A token was obtained (a metadata server is reachable), so there is
            // no error message to check.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let Err(e) = Builder::default().build_token_provider().token().await else {
            // A token was obtained (a metadata server is reachable), so there is
            // no error message to check.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .without_access_boundary()
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .without_access_boundary()
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        // Disable access boundaries (like the sibling tests) so the boundary
        // lookup cannot send extra requests to the mock server, which only
        // expects the single token request verified below.
        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .without_access_boundary()
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        // The clock is paused, so the expiry is exactly `expires_in` after "now".
        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .without_access_boundary()
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_mds_signer() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "POST",
                    "/v1/projects/-/serviceAccounts/test-client-email:signBlob"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "signedBlob": BASE64_STANDARD.encode("signed_blob"),
            }))),
        );

        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let signer = Builder::default()
            .with_endpoint(&endpoint)
            .maybe_iam_endpoint_override(Some(endpoint))
            .without_access_boundary()
            .build_signer()?;

        let client_email = signer.client_email().await?;
        assert_eq!(client_email, "test-client-email");

        let signature = signer.sign(b"test").await?;
        assert_eq!(signature.as_ref(), b"signed_blob");

        Ok(())
    }

    #[tokio::test]
    #[parallel]
    #[cfg(google_cloud_unstable_trusted_boundaries)]
    async fn e2e_access_boundary() -> TestResult {
        use crate::credentials::tests::get_access_boundary_from_headers;

        let server = Server::run();
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/token")),])
                .respond_with(json_encoded(MDSTokenResponse {
                    access_token: "test-access-token".to_string(),
                    expires_in: None,
                    token_type: "Bearer".to_string(),
                })),
        );
        server.expect(
            Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
                .respond_with(status_code(200).body("test-client-email")),
        );
        server.expect(
            Expectation::matching(all_of![
                request::method_path(
                    "GET",
                    "/v1/projects/-/serviceAccounts/test-client-email/allowedLocations"
                ),
                request::headers(contains(("authorization", "Bearer test-access-token"))),
            ])
            .respond_with(json_encoded(json!({
                "locations": ["us-central1", "us-east1"],
                "encodedLocations": "0x1234"
            }))),
        );

        let endpoint = server.url("").to_string().trim_end_matches('/').to_string();

        let creds = Builder::default()
            .with_endpoint(&endpoint)
            .maybe_iam_endpoint_override(Some(endpoint))
            .build_credentials()?;

        creds.wait_for_boundary().await;

        let headers = creds.headers(Extensions::new()).await?;
        let token = get_token_from_headers(headers.clone());
        let access_boundary = get_access_boundary_from_headers(headers);
        assert!(token.is_some(), "should have some token: {token:?}");
        assert_eq!(
            access_boundary.as_deref(),
            Some("0x1234"),
            "should be 0x1234 but found: {access_boundary:?}"
        );

        Ok(())
    }
}