1use crate::credentials::dynamic::CredentialsProvider;
59use crate::credentials::{CacheableResource, Credentials, DEFAULT_UNIVERSE_DOMAIN};
60use crate::errors::CredentialsError;
61use crate::headers_util::build_cacheable_headers;
62use crate::token::{CachedTokenProvider, Token, TokenProvider};
63use crate::token_cache::TokenCache;
64use crate::{BuildResult, Result};
65use async_trait::async_trait;
66use bon::Builder;
67use http::{Extensions, HeaderMap, HeaderValue};
68use reqwest::Client;
69use std::default::Default;
70use std::sync::Arc;
71use std::time::Duration;
72use tokio::time::Instant;
73
/// Value required in the `metadata-flavor` header of every metadata-service request.
const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header carrying [`METADATA_FLAVOR_VALUE`].
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Default metadata-service root, used when no endpoint override is configured.
const METADATA_ROOT: &str = "http://metadata.google.internal";
/// Path of the default service account; the token endpoint is this path plus `/token`.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable overriding the metadata-service host (optionally `host:port`).
/// Its value is prefixed with `http://` when building the endpoint.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
/// Error context used when credentials selected by Application Default Credentials
/// fail to fetch a token from the default (non-overridden) endpoint. It points users
/// at setting up local credentials.
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
87
/// Credentials whose access tokens come from the metadata service (MDS).
///
/// Tokens are obtained from `token_provider`. The optional quota project is added to
/// the request headers, and the optional universe domain is reported by
/// `universe_domain()`.
#[derive(Debug)]
struct MDSCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Quota project forwarded to `build_cacheable_headers`, if configured.
    quota_project_id: Option<String>,
    /// Universe domain override; `None` means `DEFAULT_UNIVERSE_DOMAIN`.
    universe_domain: Option<String>,
    /// Source of (cached) access tokens.
    token_provider: T,
}
97
/// Builds [Credentials] that fetch access tokens from the metadata service.
///
/// All settings are optional. `Builder::default().build()` talks to the default
/// metadata-service root unless the `GCE_METADATA_HOST` environment variable is set.
#[derive(Debug, Default)]
pub struct Builder {
    /// Endpoint set via `with_endpoint`; `GCE_METADATA_HOST`, when set, takes precedence.
    endpoint: Option<String>,
    /// Quota project sent along with the access token, if any.
    quota_project_id: Option<String>,
    /// OAuth scopes requested from the metadata service, if any.
    scopes: Option<Vec<String>>,
    /// Universe domain reported by the credentials, if overridden.
    universe_domain: Option<String>,
    /// True when created via `from_adc`; selects the ADC-specific error message.
    created_by_adc: bool,
}
117
118impl Builder {
119 pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
134 self.endpoint = Some(endpoint.into());
135 self
136 }
137
138 pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
147 self.quota_project_id = Some(quota_project_id.into());
148 self
149 }
150
151 pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
158 self.universe_domain = Some(universe_domain.into());
159 self
160 }
161
162 pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
171 where
172 I: IntoIterator<Item = S>,
173 S: Into<String>,
174 {
175 self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
176 self
177 }
178
179 pub(crate) fn from_adc() -> Self {
181 Self {
182 created_by_adc: true,
183 ..Default::default()
184 }
185 }
186
187 fn build_token_provider(self) -> MDSAccessTokenProvider {
188 let final_endpoint: String;
189 let endpoint_overridden: bool;
190
191 if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
193 final_endpoint = format!("http://{}", host_from_env);
195 endpoint_overridden = true;
196 } else if let Some(builder_endpoint) = self.endpoint {
197 final_endpoint = builder_endpoint;
199 endpoint_overridden = true;
200 } else {
201 final_endpoint = METADATA_ROOT.to_string();
203 endpoint_overridden = false;
204 };
205
206 MDSAccessTokenProvider::builder()
207 .endpoint(final_endpoint)
208 .maybe_scopes(self.scopes)
209 .endpoint_overridden(endpoint_overridden)
210 .created_by_adc(self.created_by_adc)
211 .build()
212 }
213
214 pub fn build(self) -> BuildResult<Credentials> {
216 let mdsc = MDSCredentials {
217 quota_project_id: self.quota_project_id.clone(),
218 universe_domain: self.universe_domain.clone(),
219 token_provider: TokenCache::new(self.build_token_provider()),
220 };
221 Ok(Credentials {
222 inner: Arc::new(mdsc),
223 })
224 }
225}
226
227#[async_trait::async_trait]
228impl<T> CredentialsProvider for MDSCredentials<T>
229where
230 T: CachedTokenProvider,
231{
232 async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
233 let cached_token = self.token_provider.token(extensions).await?;
234 build_cacheable_headers(&cached_token, &self.quota_project_id)
235 }
236
237 async fn universe_domain(&self) -> Option<String> {
238 if self.universe_domain.is_some() {
239 return self.universe_domain.clone();
240 }
241 return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
242 }
243}
244
/// Service-account description with email, scopes and aliases.
///
/// Presumably this mirrors the metadata service's service-account document.
/// NOTE(review): not referenced elsewhere in this file; confirm it is still needed.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct ServiceAccountInfo {
    email: String,
    scopes: Option<Vec<String>>,
    aliases: Option<Vec<String>>,
}
251
/// JSON body returned by the metadata service's token endpoint.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    /// The access token.
    access_token: String,
    /// Token lifetime in seconds, if reported. It is converted to an absolute expiry
    /// when the token is fetched.
    // Left out of serialized output when `None`. In this file, serialization is only
    // used by tests to build fake responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    /// Token type, used as the scheme of the authorization header (e.g. `Bearer`).
    token_type: String,
}
259
/// Fetches access tokens for the default service account from the metadata service.
///
/// Instances are created through the `bon`-generated `MDSAccessTokenProvider::builder()`.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    /// OAuth scopes to request; joined with commas into the `scopes` query parameter.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    /// Base URL of the metadata service (scheme, host and optional port).
    #[builder(into)]
    endpoint: String,
    /// True when `endpoint` came from `GCE_METADATA_HOST` or `Builder::with_endpoint`
    /// instead of the default root.
    endpoint_overridden: bool,
    /// True when these credentials were selected by Application Default Credentials.
    created_by_adc: bool,
}
269
270impl MDSAccessTokenProvider {
271 fn error_message(&self) -> &str {
279 if self.use_adc_message() {
280 MDS_NOT_FOUND_ERROR
281 } else {
282 "failed to fetch token"
283 }
284 }
285
286 fn use_adc_message(&self) -> bool {
287 self.created_by_adc && !self.endpoint_overridden
288 }
289}
290
291#[async_trait]
292impl TokenProvider for MDSAccessTokenProvider {
293 async fn token(&self) -> Result<Token> {
294 let client = Client::new();
295 let request = client
296 .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
297 .header(
298 METADATA_FLAVOR,
299 HeaderValue::from_static(METADATA_FLAVOR_VALUE),
300 );
301 let scopes = self.scopes.as_ref().map(|v| v.join(","));
304 let request = scopes
305 .into_iter()
306 .fold(request, |r, s| r.query(&[("scopes", s)]));
307
308 let response = request
313 .send()
314 .await
315 .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
316 if !response.status().is_success() {
318 let err = crate::errors::from_http_response(response, self.error_message()).await;
319 return Err(err);
320 }
321 let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
322 CredentialsError::from_source(!e.is_decode(), e)
326 })?;
327 let token = Token {
328 token: response.access_token,
329 token_type: response.token_type,
330 expires_at: response
331 .expires_in
332 .map(|d| Instant::now() + Duration::from_secs(d)),
333 metadata: None,
334 };
335 Ok(token)
336 }
337}
338
#[cfg(test)]
mod test {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::test::{
        get_headers_from_cache, get_token_from_headers, get_token_type_from_headers,
    };
    use crate::errors;
    use crate::token::test::MockTokenProvider;
    use axum::extract::Query;
    use axum::response::IntoResponse;
    use http::header::AUTHORIZATION;
    use reqwest::StatusCode;
    use reqwest::header::HeaderMap;
    use scoped_env::ScopedEnv;
    use serde::Deserialize;
    use serde_json::Value;
    use serial_test::{parallel, serial};
    use std::collections::HashMap;
    use std::error::Error;
    use std::sync::Mutex;
    use test_case::test_case;
    use tokio::task::JoinHandle;
    use url::Url;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Query parameters the fake metadata server expects on a token request.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct TokenQueryParams {
        scopes: Option<String>,
        recursive: Option<String>,
    }

    #[test]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{}{}", METADATA_ROOT, MDS_DEFAULT_URI));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address =
            Url::parse(&format!("{}{}/token", METADATA_ROOT, MDS_DEFAULT_URI));
        assert!(token_endpoint_address.is_ok());
    }

    #[tokio::test]
    async fn headers_success() -> TestResult {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        // The mock allows exactly one token fetch; the second `headers()` call below
        // must be served from the cache.
        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };

        // Without an entity tag the cache must return fresh headers.
        let mut extensions = Extensions::new();
        let cached_headers = mdsc.headers(extensions.clone()).await?;
        let (headers, entity_tag) = match cached_headers {
            CacheableResource::New { entity_tag, data } => (data, entity_tag),
            CacheableResource::NotModified => unreachable!("expecting new headers"),
        };
        let token = headers.get(AUTHORIZATION).unwrap();
        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());

        // Presenting the entity tag of the current headers must yield `NotModified`.
        extensions.insert(entity_tag);

        let cached_headers = mdsc.headers(extensions).await?;
        assert!(
            matches!(cached_headers, CacheableResource::NotModified),
            "expecting NotModified headers"
        );
        Ok(())
    }

    #[tokio::test]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    #[test]
    fn error_message_with_adc() {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(true)
            .endpoint_overridden(false)
            .build();

        let want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(got.contains(want), "{got}, {provider:?}");
    }

    #[test_case(false, false)]
    #[test_case(false, true)]
    #[test_case(true, true)]
    fn error_message_without_adc(adc: bool, overridden: bool) {
        let provider = MDSAccessTokenProvider::builder()
            .endpoint("http://127.0.0.1")
            .created_by_adc(adc)
            .endpoint_overridden(overridden)
            .build();

        let not_want = MDS_NOT_FOUND_ERROR;
        let got = provider.error_message();
        assert!(!got.contains(not_want), "{got}, {provider:?}");
    }

    // Assumes no metadata server is reachable at the default endpoint from the test
    // environment.
    #[tokio::test]
    #[serial]
    async fn adc_no_mds() -> TestResult {
        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        assert!(err.is_transient(), "{err:?}");
        assert!(
            err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn adc_overridden_mds() -> TestResult {
        let env = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");

        let err = Builder::from_adc()
            .build_token_provider()
            .token()
            .await
            .unwrap_err();

        // Restore the environment as soon as the override is no longer needed.
        drop(env);

        assert!(err.is_transient(), "{err:?}");
        assert!(
            !err.to_string().contains("application-default"),
            "display={err}, debug={err:?}"
        );
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");

        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn builder_no_mds() -> TestResult {
        let e = Builder::default()
            .build_token_provider()
            .token()
            .await
            .err()
            .unwrap();

        assert!(e.is_transient(), "{e:?}");
        assert!(
            !format!("{:?}", e.source()).contains("application-default"),
            "{e:?}"
        );

        Ok(())
    }

    /// Builds the fake server's response from a status, headers and JSON body.
    fn handle_token_factory(
        response_code: StatusCode,
        response_headers: HeaderMap,
        response_body: Value,
    ) -> impl IntoResponse {
        (response_code, response_headers, response_body.to_string()).into_response()
    }

    /// Maps a path to (status, body, expected query, call counter).
    type Handlers = HashMap<String, (StatusCode, Value, TokenQueryParams, Arc<Mutex<i32>>)>;

    /// Starts a fake metadata server on an ephemeral local port.
    ///
    /// Each handler asserts the expected query parameters and increments its call
    /// counter. Returns the server's base URL and the task serving it.
    async fn start(path_handlers: Handlers) -> (String, JoinHandle<()>) {
        let mut app = axum::Router::new();

        for (path, (code, body, expected_query, call_count)) in path_handlers {
            let header_map = HeaderMap::new();
            let handler = move |Query(query): Query<TokenQueryParams>| {
                let body = body.clone();
                let header_map = header_map.clone();
                async move {
                    assert_eq!(expected_query, query);
                    let mut count = call_count.lock().unwrap();
                    *count += 1;
                    handle_token_factory(code, header_map, body)
                }
            };
            app = app.route(&path, axum::routing::get(handler));
        }

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (format!("http://{}:{}", addr.ip(), addr.port()), server)
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() {
        let scopes = ["scope1".to_string(), "scope2".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        // `GCE_METADATA_HOST` holds a bare `host:port`; the builder adds `http://`.
        let env = ScopedEnv::set(
            super::GCE_METADATA_HOST_ENV_VAR,
            endpoint.strip_prefix("http://").unwrap_or(&endpoint),
        );
        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .build()
            .unwrap();
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        drop(env);

        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let scopes = ["scope1".to_string(), "scope2".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(endpoint)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let call_count = Arc::new(Mutex::new(0));
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                call_count.clone(),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(endpoint)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers).unwrap(),
            "test-access-token"
        );

        // The second call must be served from the cache.
        assert_eq!(*call_count.lock().unwrap(), 1);

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let token = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: None,
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let token = Builder::default()
            .with_endpoint(endpoint)
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(headers.clone()).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::SERVICE_UNAVAILABLE,
                serde_json::to_value("try again")?,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;
        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        assert!(err.is_transient());
        assert!(err.to_string().contains("try again"), "{err:?}");
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::UNAUTHORIZED,
                serde_json::to_value("epic fail".to_string())?,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;

        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
        assert!(!err.is_transient());
        assert!(err.to_string().contains("epic fail"), "{err:?}");
        let source = err
            .source()
            .and_then(|e| e.downcast_ref::<reqwest::Error>());
        assert!(
            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
            "{err:?}"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                serde_json::to_value("bad json".to_string())?,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_transient());

        Ok(())
    }

    #[tokio::test]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    #[tokio::test]
    async fn get_custom_universe_domain_success() -> TestResult {
        let universe_domain = "test-universe";
        let universe_domain_response = Builder::default()
            .with_universe_domain(universe_domain)
            .build()?
            .universe_domain()
            .await
            .unwrap();
        assert_eq!(universe_domain_response, universe_domain);

        Ok(())
    }
}