1use crate::credentials::dynamic::CredentialsProvider;
78use crate::credentials::{CacheableResource, Credentials};
79use crate::errors::CredentialsError;
80use crate::headers_util::build_cacheable_headers;
81use crate::retry::{Builder as RetryTokenProviderBuilder, TokenProviderWithRetry};
82use crate::token::{CachedTokenProvider, Token, TokenProvider};
83use crate::token_cache::TokenCache;
84use crate::{BuildResult, Result};
85use async_trait::async_trait;
86use bon::Builder;
87use gax::backoff_policy::BackoffPolicyArg;
88use gax::retry_policy::RetryPolicyArg;
89use gax::retry_throttler::RetryThrottlerArg;
90use http::{Extensions, HeaderMap, HeaderValue};
91use reqwest::Client;
92use std::default::Default;
93use std::sync::Arc;
94use std::time::Duration;
95use tokio::time::Instant;
96
/// Value of the `metadata-flavor` header sent on every metadata server request.
const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header carrying [METADATA_FLAVOR_VALUE].
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Base URL of the metadata server used when no override is configured.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path of the default service account; the access-token (`/token`) and
/// ID-token (`/identity`) endpoints are appended to it.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable overriding the metadata server host. Its value is a bare
/// host (optionally with port); the `http://` scheme is prepended by the builders.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
/// Error message used by credentials created through Application Default
/// Credentials when fetching a token from the default (non-overridden) metadata
/// server fails. It points users at setting up local credentials instead.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
110
/// Access-token credentials whose tokens come from the metadata server.
#[derive(Debug)]
struct MDSCredentials<T>
where
    T: CachedTokenProvider,
{
    // Quota project added to the request headers when set.
    quota_project_id: Option<String>,
    // Source of (cached) access tokens used to build the authorization header.
    token_provider: T,
}
119
120#[derive(Debug, Default)]
132pub struct Builder {
133 endpoint: Option<String>,
134 quota_project_id: Option<String>,
135 scopes: Option<Vec<String>>,
136 created_by_adc: bool,
137 retry_builder: RetryTokenProviderBuilder,
138}
139
140impl Builder {
141 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
156 self.endpoint = Some(endpoint.into());
157 self
158 }
159
160 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
169 self.quota_project_id = Some(quota_project_id.into());
170 self
171 }
172
173 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
182 where
183 I: IntoIterator<Item = S>,
184 S: Into<String>,
185 {
186 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
187 self
188 }
189
190 pub fn with_retry_policy<V: Into<RetryPolicyArg>>(mut self, v: V) -> Self {
205 self.retry_builder = self.retry_builder.with_retry_policy(v.into());
206 self
207 }
208
209 pub fn with_backoff_policy<V: Into<BackoffPolicyArg>>(mut self, v: V) -> Self {
225 self.retry_builder = self.retry_builder.with_backoff_policy(v.into());
226 self
227 }
228
229 pub fn with_retry_throttler<V: Into<RetryThrottlerArg>>(mut self, v: V) -> Self {
250 self.retry_builder = self.retry_builder.with_retry_throttler(v.into());
251 self
252 }
253
254 pub(crate) fn from_adc() -> Self {
256 Self {
257 created_by_adc: true,
258 ..Default::default()
259 }
260 }
261
262 fn build_token_provider(self) -> TokenProviderWithRetry<MDSAccessTokenProvider> {
263 let final_endpoint: String;
264 let endpoint_overridden: bool;
265
266 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
268 final_endpoint = format!("http://{host_from_env}");
270 endpoint_overridden = true;
271 } else if let Some(builder_endpoint) = self.endpoint {
272 final_endpoint = builder_endpoint;
274 endpoint_overridden = true;
275 } else {
276 final_endpoint = METADATA_ROOT.to_string();
278 endpoint_overridden = false;
279 };
280
281 let tp = MDSAccessTokenProvider::builder()
282 .endpoint(final_endpoint)
283 .maybe_scopes(self.scopes)
284 .endpoint_overridden(endpoint_overridden)
285 .created_by_adc(self.created_by_adc)
286 .build();
287 self.retry_builder.build(tp)
288 }
289
290 pub fn build(self) -> BuildResult<Credentials> {
292 let mdsc = MDSCredentials {
293 quota_project_id: self.quota_project_id.clone(),
294 token_provider: TokenCache::new(self.build_token_provider()),
295 };
296 Ok(Credentials {
297 inner: Arc::new(mdsc),
298 })
299 }
300}
301
302#[async_trait::async_trait]
303impl<T> CredentialsProvider for MDSCredentials<T>
304where
305 T: CachedTokenProvider,
306{
307 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
308 let cached_token = self.token_provider.token(extensions).await?;
309 build_cacheable_headers(&cached_token, &self.quota_project_id)
310 }
311}
312
/// JSON body returned by the metadata server's access-token endpoint.
///
/// `Serialize` is used by the tests to build mock server responses.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    access_token: String,
    // Token lifetime in seconds, counted from when the response is processed.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    token_type: String,
}
320
/// Fetches access tokens from the metadata server's `/token` endpoint.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    // When set, joined with `,` and sent as the `scopes` query parameter.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    // Base URL of the metadata server; the request path is appended verbatim,
    // so it is expected to have no trailing `/`.
    #[builder(into)]
    endpoint: String,
    // True when `endpoint` came from `GCE_METADATA_HOST` or an explicit builder value.
    endpoint_overridden: bool,
    // True when this provider was created for Application Default Credentials.
    created_by_adc: bool,
}
330
331impl MDSAccessTokenProvider {
332 fn error_message(&self) -> &str {
340 if self.use_adc_message() {
341 MDS_NOT_FOUND_ERROR
342 } else {
343 "failed to fetch token"
344 }
345 }
346
347 fn use_adc_message(&self) -> bool {
348 self.created_by_adc && !self.endpoint_overridden
349 }
350}
351
352#[async_trait]
353impl TokenProvider for MDSAccessTokenProvider {
354 async fn token(&self) -> Result<Token> {
355 let client = Client::new();
356 let request = client
357 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
358 .header(
359 METADATA_FLAVOR,
360 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
361 );
362 let scopes = self.scopes.as_ref().map(|v| v.join(","));
365 let request = scopes
366 .into_iter()
367 .fold(request, |r, s| r.query(&[("scopes", s)]));
368
369 let response = request
374 .send()
375 .await
376 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
377 if !response.status().is_success() {
379 let err = crate::errors::from_http_response(response, self.error_message()).await;
380 return Err(err);
381 }
382 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
383 CredentialsError::from_source(!e.is_decode(), e)
387 })?;
388 let token = Token {
389 token: response.access_token,
390 token_type: response.token_type,
391 expires_at: response
392 .expires_in
393 .map(|d| Instant::now() + Duration::from_secs(d)),
394 metadata: None,
395 };
396 Ok(token)
397 }
398}
399
400#[allow(dead_code)]
401pub(crate) mod idtoken {
402 use std::sync::Arc;
404
405 use super::{
406 GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_FLAVOR, METADATA_FLAVOR_VALUE,
407 METADATA_ROOT,
408 };
409 use crate::Result;
410 use crate::errors::CredentialsError;
411 use crate::token::{Token, TokenProvider};
412 use crate::{
413 BuildResult,
414 credentials::idtoken::{IDTokenCredentials, dynamic::IDTokenCredentialsProvider},
415 };
416 use async_trait::async_trait;
417 use http::HeaderValue;
418 use reqwest::Client;
419
420 #[derive(Debug)]
421 pub(crate) struct MDSCredentials<T>
422 where
423 T: TokenProvider,
424 {
425 token_provider: T,
426 }
427
428 #[async_trait]
429 impl<T> IDTokenCredentialsProvider for MDSCredentials<T>
430 where
431 T: TokenProvider,
432 {
433 async fn id_token(&self) -> Result<Token> {
434 self.token_provider.token().await
435 }
436 }
437
438 #[derive(Debug, Default)]
441 pub struct Builder {
442 endpoint: Option<String>,
443 format: Option<String>,
444 licenses: Option<String>,
445 target_audience: String,
446 }
447
448 impl Builder {
449 pub fn new<S: Into<String>>(target_audience: S) -> Self {
455 Builder {
456 format: None,
457 endpoint: None,
458 licenses: None,
459 target_audience: target_audience.into(),
460 }
461 }
462
463 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
468 self.endpoint = Some(endpoint.into());
469 self
470 }
471
472 pub fn with_format<S: Into<String>>(mut self, format: S) -> Self {
480 self.format = Some(format.into());
481 self
482 }
483
484 pub fn with_licenses(mut self, licenses: bool) -> Self {
491 self.licenses = if licenses {
492 Some("TRUE".to_string())
493 } else {
494 Some("FALSE".to_string())
495 };
496 self
497 }
498
499 fn build_token_provider(self) -> MDSTokenProvider {
500 let final_endpoint: String;
501
502 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
504 final_endpoint = format!("http://{host_from_env}");
506 } else if let Some(builder_endpoint) = self.endpoint {
507 final_endpoint = builder_endpoint;
509 } else {
510 final_endpoint = METADATA_ROOT.to_string();
512 };
513
514 MDSTokenProvider {
515 format: self.format,
516 licenses: self.licenses,
517 endpoint: final_endpoint,
518 target_audience: self.target_audience,
519 }
520 }
521
522 pub fn build(self) -> BuildResult<IDTokenCredentials> {
525 let creds = MDSCredentials {
526 token_provider: self.build_token_provider(),
527 };
528 Ok(IDTokenCredentials {
529 inner: Arc::new(creds),
530 })
531 }
532 }
533
534 #[derive(Debug, Clone, Default)]
535 struct MDSTokenProvider {
536 endpoint: String,
537 format: Option<String>,
538 licenses: Option<String>,
539 target_audience: String,
540 }
541
542 #[async_trait]
543 impl TokenProvider for MDSTokenProvider {
544 async fn token(&self) -> Result<Token> {
545 let client = Client::new();
546 let audience = self.target_audience.clone();
547 let request = client
548 .get(format!("{}{}/identity", self.endpoint, MDS_DEFAULT_URI))
549 .header(
550 METADATA_FLAVOR,
551 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
552 )
553 .query(&[("audience", audience)]);
554 let request = self.format.iter().fold(request, |builder, format| {
555 builder.query(&[("format", format)])
556 });
557 let request = self.licenses.iter().fold(request, |builder, licenses| {
558 builder.query(&[("licenses", licenses)])
559 });
560
561 let response = request
562 .send()
563 .await
564 .map_err(|e| crate::errors::from_http_error(e, "failed to fetch token"))?;
565
566 if !response.status().is_success() {
567 let err =
568 crate::errors::from_http_response(response, "failed to fetch token").await;
569 return Err(err);
570 }
571
572 let token = response
573 .text()
574 .await
575 .map_err(|e| CredentialsError::from_source(!e.is_decode(), e))?;
576
577 Ok(Token {
578 token,
579 token_type: "Bearer".to_string(),
580 expires_at: None,
582 metadata: None,
583 })
584 }
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::idtoken;
    use super::*;
    use crate::credentials::DEFAULT_UNIVERSE_DOMAIN;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::tests::{
        find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
        get_mock_backoff_policy, get_mock_retry_throttler, get_token_from_headers,
        get_token_type_from_headers,
    };
    use crate::errors;
    use crate::errors::CredentialsError;
    use crate::token::tests::MockTokenProvider;
    use http::HeaderValue;
    use http::header::AUTHORIZATION;
    use httptest::cycle;
    use httptest::matchers::{all_of, contains, request, url_decoded};
    use httptest::responders::{json_encoded, status_code};
    use httptest::{Expectation, Server};
    use reqwest::StatusCode;
    use scoped_env::ScopedEnv;
    use serial_test::{parallel, serial};
    use std::error::Error;
    use test_case::test_case;
    use url::Url;

    type TestResult = anyhow::Result<()>;

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_on_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(status_code(503)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_does_not_retry_on_non_transient_failures() -> TestResult {
        let mut server = Server::run();
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(1)
                .respond_with(status_code(401)),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(1))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let err = provider.token().await.unwrap_err();
        assert!(!err.is_transient());
        server.verify_and_clear();
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_mds_retries_for_success() -> TestResult {
        let mut server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };

        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .times(3)
                .respond_with(cycle![
                    status_code(503).body("try-again"),
                    status_code(503).body("try-again"),
                    status_code(200)
                        .append_header("Content-Type", "application/json")
                        .body(serde_json::to_string(&response).unwrap()),
                ]),
        );

        let provider = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_retry_policy(get_mock_auth_retry_policy(3))
            .with_backoff_policy(get_mock_backoff_policy())
            .with_retry_throttler(get_mock_retry_throttler())
            .build_token_provider();

        let token = provider.token().await?;
        assert_eq!(token.token, "test-access-token");

        server.verify_and_clear();
        Ok(())
    }

    #[test]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}"));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address = Url::parse(&format!("{METADATA_ROOT}{MDS_DEFAULT_URI}/token"));
        assert!(token_endpoint_address.is_ok());
    }

    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };

        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        extensions.insert(entity_tag);

        // Presenting the entity tag from the first call must yield `NotModified`,
        // because the cached token has not changed.
        let cached_headers = mdsc.headers(extensions).await?;
        match cached_headers {
            CacheableResource::New { .. } => {
                unreachable!("expecting NotModified when the entity tag matches")
            }
            CacheableResource::NotModified => {}
        }
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    #[test]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(true)
            .endpoint_overridden(false)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(adc)
            .endpoint_overridden(overridden)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let Err(err) = Builder::from_adc().build_token_provider().token().await else {
            // The fetch succeeded, presumably because a real metadata server is
            // reachable from where the test runs; there is nothing to verify.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(
            original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        assert!(
            !original_err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let Err(e) = Builder::default().build_token_provider().token().await else {
            // The fetch succeeded, presumably because a real metadata server is
            // reachable from where the test runs; there is nothing to verify.
            return Ok(());
        };

        let original_err = find_source_error::<CredentialsError>(&e).unwrap();
        assert!(
            !format!("{:?}", original_err.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let server = Server::run();
        let scopes = ["scope1", "scope2"];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(format!("http://{}", server.addr()))
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let mut server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .times(1)
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        // The server expects exactly one request: the second call must hit the cache.
        server.verify_and_clear();

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let server = Server::run();
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(request::path(format!("{MDS_DEFAULT_URI}/token")))
                .respond_with(json_encoded(response)),
        );

        let token = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded(response)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(503)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(status_code(401)),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        let original_err = find_source_error::<CredentialsError>(&err).unwrap();
        assert!(!original_err.is_transient());
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let server = Server::run();
        let scopes = vec!["scope1".to_string()];
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/token")),
                request::query(url_decoded(contains(("scopes", scopes.join(",")))))
            ])
            .respond_with(json_encoded("bad json")),
        );

        let mdsc = Builder::default()
            .with_endpoint(format!("http://{}", server.addr()))
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    #[tokio::test]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_idtoken_builder_build() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let format = "format";
        let token_string = "test-id-token";
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience)))),
                request::query(url_decoded(contains(("format", format)))),
                request::query(url_decoded(contains(("licenses", "TRUE"))))
            ])
            .respond_with(status_code(200).body(token_string)),
        );

        let creds = idtoken::Builder::new(audience)
            .with_endpoint(format!("http://{}", server.addr()))
            .with_format(format)
            .with_licenses(true)
            .build()?;

        let token = creds.id_token().await?;
        assert_eq!(token.token, token_string);
        assert_eq!(token.token_type, "Bearer");
        assert!(token.expires_at.is_none());
        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_idtoken_builder_build_with_env_var() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        let token_string = "test-id-token";
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .respond_with(status_code(200).body(token_string)),
        );

        let addr = server.addr().to_string();
        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, &addr);

        let creds = idtoken::Builder::new(audience).build()?;

        let token = creds.id_token().await?;
        assert_eq!(token.token, token_string);

        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn test_idtoken_provider_http_error() -> TestResult {
        let server = Server::run();
        let audience = "test-audience";
        server.expect(
            Expectation::matching(all_of![
                request::path(format!("{MDS_DEFAULT_URI}/identity")),
                request::query(url_decoded(contains(("audience", audience))))
            ])
            .respond_with(status_code(503)),
        );

        let creds = idtoken::Builder::new(audience)
            .with_endpoint(format!("http://{}", server.addr()))
            .build()?;

        let err = creds.id_token().await.unwrap_err();
        let source = find_source_error::<reqwest::Error>(&err);
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );
        Ok(())
    }
}