1use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Value sent in the [`METADATA_FLAVOR`] header on every metadata server request.
const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header that carries [`METADATA_FLAVOR_VALUE`].
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Root URL used when neither the environment nor the builder overrides the endpoint.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path (relative to the metadata root) of the default service account resources.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable whose value, when set, replaces the metadata server host.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
/// Error message used when the credentials were created via ADC and the
/// endpoint was not overridden.
///
/// Written as one literal; each trailing `\` joins the next line without
/// inserting the newline or the leading indentation.
const MDS_NOT_FOUND_ERROR: &str = "Could not fetch an auth token to authenticate with Google Cloud. \
    The most common reason for this problem is that you are not running in a Google Cloud Environment \
    and you have not configured local credentials for development and testing. \
    To setup local credentials, run `gcloud auth application-default login`. \
    More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries";
110
111#[derive(Debug)]
112struct MDSCredentials<T>
113where
114 T: CachedTokenProvider,
115{
116 quota_project_id: Option<String>,
117 token_provider: T,
118}
119
120#[derive(Debug, Default)]
132pub struct Builder {
133 endpoint: Option<String>,
134 quota_project_id: Option<String>,
135 scopes: Option<Vec<String>>,
136 created_by_adc: bool,
137 retry_builder: RetryTokenProviderBuilder,
138}
139
140impl Builder {
141 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156 self.endpoint = Some(endpoint.into());
157 self
158 }
159
160 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169 self.quota_project_id = Some(quota_project_id.into());
170 self
171 }
172
173 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182 where
183 I: IntoIterator<Item = S>,
184 S: Into<String>,
185 {
186 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187 self
188 }
189
190 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206 self
207 }
208
209 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226 self
227 }
228
229 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251 self
252 }
253
254 pub(crate) fn from_adc() -> Self {
256 Self {
257 created_by_adc: true,
258 ..Default::default()
259 }
260 }
261
262 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263 let final_endpoint: String;
264 let endpoint_overridden: bool;
265
266 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268 final_endpoint = format!("http://{host_from_env}");
270 endpoint_overridden = true;
271 } else if let Some(builder_endpoint) = self.endpoint {
272 final_endpoint = builder_endpoint;
274 endpoint_overridden = true;
275 } else {
276 final_endpoint = METADATA_ROOT.to_string();
278 endpoint_overridden = false;
279 };
280
281 let tp = MDSAccessTokenProvider::builder()
282 .endpoint(final_endpoint)
283 .maybe_scopes(self.scopes)
284 .endpoint_overridden(endpoint_overridden)
285 .created_by_adc(self.created_by_adc)
286 .build();
287 self.retry_builder.build(tp)
288 }
289
290 pub fn build(self) -> BuildResult<Credentials> {
292 let mdsc = MDSCredentials {
293 quota_project_id: self.quota_project_id.clone(),
294 token_provider: TokenCache::new(self.build_token_provider()),
295 };
296 Ok(Credentials {
297 inner: Arc::new(mdsc),
298 })
299 }
300}
301
302#[async_trait::async_trait]
303impl<T> CredentialsProvider for MDSCredentials<T>
304where
305 T: CachedTokenProvider,
306{
307 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
308 let cached_token = self.token_provider.token(extensions).await?;
309 build_cacheable_headers(&cached_token, &self.quota_project_id)
310 }
311}
312
313#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
314struct MDSTokenResponse {
315 access_token: String,
316 #[serde(skip_serializing_if = "Option::is_none")]
317 expires_in: Option<u64>,
318 token_type: String,
319}
320
321#[derive(Debug, Clone, Default, Builder)]
322struct MDSAccessTokenProvider {
323 #[builder(into)]
324 scopes: Option<Vec<String>>,
325 #[builder(into)]
326 endpoint: String,
327 endpoint_overridden: bool,
328 created_by_adc: bool,
329}
330
331impl MDSAccessTokenProvider {
332 fn error_message(&self) -> &str {
340 if self.use_adc_message() {
341 MDS_NOT_FOUND_ERROR
342 } else {
343 "failed to fetch token"
344 }
345 }
346
347 fn use_adc_message(&self) -> bool {
348 self.created_by_adc && !self.endpoint_overridden
349 }
350}
351
352#[async_trait]
353impl TokenProvider for MDSAccessTokenProvider {
354 async fn token(&self) -> Result<Token> {
355 let client = Client::new();
356 let request = client
357 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
358 .header(
359 METADATA_FLAVOR,
360 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
361 );
362 let scopes = self.scopes.as_ref().map(|v| v.join(","));
365 let request = scopes
366 .into_iter()
367 .fold(request, |r, s| r.query(&[("scopes", s)]));
368
369 let response = request
374 .send()
375 .await
376 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
377 if !response.status().is_success() {
379 let err = crate::errors::from_http_response(response, self.error_message()).await;
380 return Err(err);
381 }
382 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
383 CredentialsError::from_source(!e.is_decode(), e)
387 })?;
388 let token = Token {
389 token: response.access_token,
390 token_type: response.token_type,
391 expires_at: response
392 .expires_in
393 .map(|d| Instant::now() + Duration::from_secs(d)),
394 metadata: None,
395 };
396 Ok(token)
397 }
398}
399
/// Identity token credentials obtained from the metadata server.
#[cfg(google_cloud_unstable_id_token)]
pub mod idtoken {
    use super::{
        GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_FLAVOR, METADATA_FLAVOR_VALUE,
        METADATA_ROOT,
    };
    use crate::Result;
    use crate::credentials::CacheableResource;
    use crate::errors::CredentialsError;
    use crate::token::{CachedTokenProvider, Token, TokenProvider};
    use crate::token_cache::TokenCache;
    use crate::{
        BuildResult,
        credentials::idtoken::dynamic::IDTokenCredentialsProvider,
        credentials::idtoken::{IDTokenCredentials, parse_id_token_from_str},
    };
    use async_trait::async_trait;
    use http::{Extensions, HeaderValue};
    use reqwest::Client;
    use std::sync::Arc;

    /// Identity token credentials that read tokens from a cached provider.
    #[derive(Debug)]
    pub(crate) struct MDSCredentials<T: CachedTokenProvider> {
        /// Source of the identity tokens.
        token_provider: T,
    }

    #[async_trait]
    impl<T> IDTokenCredentialsProvider for MDSCredentials<T>
    where
        T: CachedTokenProvider,
    {
        /// Returns the identity token string held by the token provider.
        ///
        /// The provider is queried with empty extensions; a `NotModified`
        /// answer is reported as a non-transient error.
        async fn id_token(&self) -> Result<String> {
            let cached = self.token_provider.token(Extensions::new()).await?;
            let CacheableResource::New { data, .. } = cached else {
                return Err(CredentialsError::from_msg(false, "failed to fetch token"));
            };
            Ok(data.token)
        }
    }

    /// Builder for [`IDTokenCredentials`] backed by the metadata server.
    #[derive(Debug, Default)]
    pub struct Builder {
        /// Endpoint override supplied through [`Builder::with_endpoint`].
        endpoint: Option<String>,
        /// Value of the `format` query parameter, if any.
        format: Option<String>,
        /// Value of the `licenses` query parameter, if any.
        licenses: Option<String>,
        /// Value of the `audience` query parameter.
        target_audience: String,
    }

    impl Builder {
        /// Creates a builder for the given audience; all other settings start unset.
        pub fn new<S: Into<String>>(target_audience: S) -> Self {
            Self {
                target_audience: target_audience.into(),
                ..Default::default()
            }
        }

        /// Sets the root URL of the metadata server.
        ///
        /// The `GCE_METADATA_HOST` environment variable, when set at build
        /// time, takes precedence over this value.
        pub fn with_endpoint<S: Into<String>>(self, endpoint: S) -> Self {
            Self {
                endpoint: Some(endpoint.into()),
                ..self
            }
        }

        /// Sets the value sent as the `format` query parameter.
        pub fn with_format<S: Into<String>>(self, format: S) -> Self {
            Self {
                format: Some(format.into()),
                ..self
            }
        }

        /// Sets the `licenses` query parameter to `TRUE` or `FALSE`.
        pub fn with_licenses(mut self, licenses: bool) -> Self {
            let flag = if licenses { "TRUE" } else { "FALSE" };
            self.licenses = Some(flag.to_string());
            self
        }

        /// Resolves the endpoint and creates the (uncached) token provider.
        ///
        /// Endpoint precedence: the `GCE_METADATA_HOST` environment variable,
        /// then the builder endpoint, then the default metadata root.
        fn build_token_provider(self) -> MDSTokenProvider {
            let endpoint = std::env::var(GCE_METADATA_HOST_ENV_VAR)
                .ok()
                .map(|host| format!("http://{host}"))
                .or(self.endpoint)
                .unwrap_or_else(|| METADATA_ROOT.to_string());

            MDSTokenProvider {
                endpoint,
                format: self.format,
                licenses: self.licenses,
                target_audience: self.target_audience,
            }
        }

        /// Returns [`IDTokenCredentials`] wrapping a cached token provider.
        ///
        /// The visible implementation always returns `Ok`.
        pub fn build(self) -> BuildResult<IDTokenCredentials> {
            let token_provider = TokenCache::new(self.build_token_provider());
            let inner = Arc::new(MDSCredentials { token_provider });
            Ok(IDTokenCredentials { inner })
        }
    }

    /// Token provider that requests identity tokens from the metadata server.
    #[derive(Debug, Clone, Default)]
    struct MDSTokenProvider {
        /// Root URL of the metadata server.
        endpoint: String,
        /// Optional `format` query parameter.
        format: Option<String>,
        /// Optional `licenses` query parameter.
        licenses: Option<String>,
        /// Value of the `audience` query parameter.
        target_audience: String,
    }

    #[async_trait]
    impl TokenProvider for MDSTokenProvider {
        /// Requests an identity token from `{endpoint}{MDS_DEFAULT_URI}/identity`.
        ///
        /// The response body is read as text and handed to
        /// `parse_id_token_from_str`.
        async fn token(&self) -> Result<Token> {
            let http = Client::new();
            let url = format!("{}{}/identity", self.endpoint, MDS_DEFAULT_URI);
            let mut request = http
                .get(url)
                .header(
                    METADATA_FLAVOR,
                    HeaderValue::from_static(METADATA_FLAVOR_VALUE),
                )
                .query(&[("audience", self.target_audience.as_str())]);
            if let Some(format) = self.format.as_deref() {
                request = request.query(&[("format", format)]);
            }
            if let Some(licenses) = self.licenses.as_deref() {
                request = request.query(&[("licenses", licenses)]);
            }

            let response = request
                .send()
                .await
                .map_err(|e| crate::errors::from_http_error(e, "failed to fetch token"))?;

            if !response.status().is_success() {
                return Err(
                    crate::errors::from_http_response(response, "failed to fetch token").await,
                );
            }

            let body = response
                .text()
                .await
                .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?;

            parse_id_token_from_str(body)
        }
    }
}
589
#[cfg(test)]
mod tests {
    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use test_case::test_case;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    /// Path of the access token resource on the fake metadata server.
    fn token_path() -> String {
        format!("{MDS_DEFAULT_URI}/token")
    }

    /// Base URL of the given fake metadata server.
    fn endpoint_of(server: &Server) -> String {
        format!("http://{}", server.addr())
    }

    /// Canned successful token response with the given lifetime.
    fn token_response(expires_in: Option<u64>) -> MDSTokenResponse {
        MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in,
            token_type: "test-token-type".to_string(),
        }
    }

    /// Three 503 responses with a 3-attempt policy end in a non-transient error.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(token_path()))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let result = provider.token().await;
        let err = result.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    /// A 401 response is requested exactly once and yields a non-transient error.
    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(token_path()))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let result = provider.token().await;
        let err = result.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    /// Two 503 responses followed by a 200 produce the token from the last response.
    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let payload = serde_json::to_string(&token_response(Some(3600))).unwrap();

        server.expect(
            Expectation::matching(request::path(token_path()))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(payload),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    /// The default root and resource paths combine into parseable URLs.
    #[test]
    fn validate_default_endpoint_urls() {
        let base = format!("{METADATA_ROOT}{MDS_DEFAULT_URI}");
        assert!(Url::parse(&base).is_ok());

        let token_url = format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token");
        assert!(Url::parse(&token_url).is_ok());
    }

    /// Headers carry the bearer token; a second call with the entity tag is `NotModified`.
    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let first = mdsc.headers(extensions.clone()).await.unwrap();
        let CacheableResource::New {
            entity_tag,
            data: headers,
        } = first
        else {
            unreachable!("expecting new headers")
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        let second = mdsc.headers(extensions).await?;
        if let CacheableResource::New { .. } = second {
            unreachable!("expecting new headers");
        }
        Ok(())
    }

    /// An error from the token provider surfaces as an error from `headers`.
    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        let result = mdsc.headers(Extensions::new()).await;
        assert!(result.is_err());
    }

    /// ADC-created providers on the default endpoint use the long guidance message.
    #[test]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(true)
            .endpoint_overridden(false)
            .build();

        let got = provider.error_message();
        let want = MDS_NOT_FOUND_ERROR;
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    /// Every other flag combination uses the short message.
    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(adc)
            .endpoint_overridden(overridden)
            .build();

        let got = provider.error_message();
        let not_want = MDS_NOT_FOUND_ERROR;
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    /// With ADC and no override, a failed fetch mentions `application-default`.
    ///
    /// The test passes trivially when the token fetch succeeds.
    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let outcome = Builder::from_adc().build_token_provider().token().await;
        let err = match outcome {
            Ok(_) => return Ok(()),
            Err(err) => err,
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        let text = original_err.to_string();
        assert!(
            text.contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    /// With the host overridden through the environment, the ADC guidance is not used.
    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let outcome = Builder::from_adc().build_token_provider().token().await;
        let err = outcome.unwrap_err();

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let text = original_err.to_string();
        assert!(
            !text.contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    /// Without ADC, a failed fetch does not mention `application-default`.
    ///
    /// The test passes trivially when the token fetch succeeds.
    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let outcome = Builder::default().build_token_provider().token().await;
        let e = match outcome {
            Ok(_) => return Ok(()),
            Err(e) => e,
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        let source_debug = format!("{:?}", original_err.source());
        assert!(!source_debug.contains("application-default"), "{e:?}");

        Ok(())
    }

    /// The `GCE_METADATA_HOST` variable redirects requests to the fake server.
    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(token_response(Some(3600)))),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default().with_scopes(scopes).build().unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        let token = get_token_from_headers(headers).unwrap();
        assert_eq!(token, "test-access-token");
        Ok(())
    }

    /// Headers contain both the authorization value and the quota project.
    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(token_response(Some(3600)))),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(endpoint_of(&server))
            .with_quota_project_id("test-project")
            .build()?;

        let cached = mdsc.headers(Extensions::new()).await.unwrap();
        let headers = get_headers_from_cache(cached)?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    /// Two `headers` calls result in a single request to the server.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(token_response(Some(3600)))),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(endpoint_of(&server))
            .build()?;

        let first = mdsc.headers(Extensions::new()).await?;
        assert_eq!(get_token_from_headers(first).unwrap(), "test-access-token");

        let second = mdsc.headers(Extensions::new()).await?;
        assert_eq!(get_token_from_headers(second).unwrap(), "test-access-token");

        server.verify_and_clear();

        Ok(())
    }

    /// The provider maps every response field, including the expiration, into the token.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(token_response(Some(3600)))),
        );

        let provider = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_scopes(scopes)
            .build_token_provider();
        let token = provider.token().await?;

        let now = tokio::time::Instant::now();
        let earliest = now + Duration::from_secs(3600);
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(token.expires_at.is_some_and(|d| d >= earliest));

        Ok(())
    }

    /// Same as `token_provider_full`, but without scopes and with an exact expiration.
    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        server.expect(
            Expectation::matching(request::path(token_path()))
                .respond_with(json_encoded(token_response(Some(3600)))),
        );

        let provider = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .build_token_provider();
        let token = provider.token().await?;

        let now = Instant::now();
        let expected = now + Duration::from_secs(3600);
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(token.expires_at.is_some_and(|d| d == expected));

        Ok(())
    }

    /// A response without `expires_in` still yields headers with the token and its type.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(token_response(None))),
        );

        let mdsc = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;

        let token = get_token_from_headers(headers.clone()).unwrap();
        assert_eq!(token, "test-access-token");
        let token_type = get_token_type_from_headers(headers).unwrap();
        assert_eq!(token_type, "test-token-type");

        Ok(())
    }

    /// A 503 response surfaces as a transient error carrying the HTTP status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    /// A 401 response surfaces as a non-transient error carrying the HTTP status.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    /// A body that does not decode as a token response is a non-transient error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(token_path()),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(endpoint_of(&server))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.unwrap_err();
        assert!(!e.is_transient());

        Ok(())
    }

    /// Credentials built with defaults report the default universe domain.
    #[tokio::test]
    async fn get_default_universe_domain_success() -> TestResult {
        let credentials = Builder::default().build()?;
        let universe_domain_response = credentials.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }
}
1157
#[cfg(all(test, google_cloud_unstable_id_token))]
mod unstable_tests {
    use super::idtoken;
    use super::*;
    use crate::credentials::idtoken::tests::generate_test_id_token;
    use crate::credentials::tests::find_source_error;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::status_code;
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};

    type TestResult = anyhow::Result<()>;

    /// Path of the identity token resource on the fake metadata server.
    fn identity_path() -> String {
        format!("{MDS_DEFAULT_URI}/identity")
    }

    /// Base URL of the given fake metadata server.
    fn endpoint_of(server: &Server) -> String {
        format!("http://{}", server.addr())
    }

    /// Audience, format and licenses are sent as query parameters; the body is the token.
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_builder_build() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let format = "format";
        let token_string = generate_test_id_token(audience);
        server.expect(
            Expectation::matching(all_of![
                request::path(identity_path()),
                request::query(url_decoded(contains(("audience", audience)))),
                request::query(url_decoded(contains(("format", format)))),
                request::query(url_decoded(contains(("licenses", "TRUE"))))
            ])
            .respond_with(status_code(200).body(token_string.clone())),
        );

        let creds = idtoken::Builder::new(audience)
            .with_endpoint(endpoint_of(&server))
            .with_format(format)
            .with_licenses(true)
            .build()?;

        let fetched = creds.id_token().await?;
        assert_eq!(fetched, token_string);
        Ok(())
    }

    /// The `GCE_METADATA_HOST` variable redirects identity requests to the fake server.
    #[tokio::test]
    #[serial]
    async fn test_idtoken_builder_build_with_env_var() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let token_string = generate_test_id_token(audience);
        server.expect(
            Expectation::matching(all_of![
                request::path(identity_path()),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .respond_with(status_code(200).body(token_string.clone())),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);

        let creds = idtoken::Builder::new(audience).build()?;

        let fetched = creds.id_token().await?;
        assert_eq!(fetched, token_string);

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
        Ok(())
    }

    /// A 503 response surfaces as an error carrying the HTTP status.
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_provider_http_error() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        server.expect(
            Expectation::matching(all_of![
                request::path(identity_path()),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .respond_with(status_code(503)),
        );

        let creds = idtoken::Builder::new(audience)
            .with_endpoint(endpoint_of(&server))
            .build()?;

        let err = creds.id_token().await.unwrap_err();
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );
        Ok(())
    }

    /// Two `id_token` calls are served by a single request to the server.
    #[tokio::test]
    #[parallel]
    async fn test_idtoken_caching() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let token_string = generate_test_id_token(audience);
        server.expect(
            Expectation::matching(all_of![
                request::path(identity_path()),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .times(1)
            .respond_with(status_code(200).body(token_string.clone())),
        );

        let creds = idtoken::Builder::new(audience)
            .with_endpoint(endpoint_of(&server))
            .build()?;

        let first = creds.id_token().await?;
        assert_eq!(first, token_string);

        let second = creds.id_token().await?;
        assert_eq!(second, token_string);

        Ok(())
    }
}