google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
//! You can use these access tokens to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
//! While the Google Cloud client libraries for Rust default to
//! using the types defined in this module, you may want to use said types directly
//! to customize some of the properties of these credentials.
36//!
37//! Example usage:
38//!
39//! ```
40//! # use google_cloud_auth::credentials::mds::Builder;
41//! # use google_cloud_auth::credentials::Credentials;
42//! # use google_cloud_auth::errors::CredentialsError;
43//! # use http::Extensions;
44//! # tokio_test::block_on(async {
45//! let credentials: Credentials = Builder::default()
46//!     .with_quota_project_id("my-quota-project")
47//!     .build()?;
48//! let headers = credentials.headers(Extensions::new()).await?;
49//! println!("Headers: {headers:?}");
50//! # Ok::<(), CredentialsError>(())
51//! # });
52//! ```
53//!
54//! [Cloud Run]: https://cloud.google.com/run
55//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
56//! [gce-link]: https://cloud.google.com/products/compute
57//! [gke-link]: https://cloud.google.com/kubernetes-engine
58//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
59
60use crate::credentials::dynamic::CredentialsProvider;
61use crate::credentials::{Credentials, DEFAULT_UNIVERSE_DOMAIN, Result};
62use crate::errors::{self, CredentialsError, is_retryable};
63use crate::headers_util::build_bearer_headers;
64use crate::token::{CachedTokenProvider, Token, TokenProvider};
65use crate::token_cache::TokenCache;
66use async_trait::async_trait;
67use bon::Builder;
68use http::{Extensions, HeaderMap, HeaderValue};
69use reqwest::Client;
70use std::default::Default;
71use std::sync::Arc;
72use std::time::Duration;
73use tokio::time::Instant;
74
/// Value sent in the `metadata-flavor` header on metadata service requests.
const METADATA_FLAVOR_VALUE: &'static str = "Google";
/// Name of the header that marks a request as a metadata service request.
const METADATA_FLAVOR: &'static str = "metadata-flavor";
/// Default base URL of the metadata service.
const METADATA_ROOT: &'static str = "http://metadata.google.internal";
/// Path, relative to the base URL, of the default service account resource.
const MDS_DEFAULT_URI: &'static str = "/computeMetadata/v1/instance/service-accounts/default";
/// Environment variable whose value (a host, without scheme) overrides the endpoint.
const GCE_METADATA_HOST_ENV_VAR: &'static str = "GCE_METADATA_HOST";
80
81#[derive(Debug)]
82struct MDSCredentials<T>
83where
84    T: CachedTokenProvider,
85{
86    quota_project_id: Option<String>,
87    universe_domain: Option<String>,
88    token_provider: T,
89}
90
/// Creates [Credentials] instances backed by the [Metadata Service].
///
/// Although the Google Cloud client libraries for Rust default to credentials
/// backed by the metadata service, some applications may need to:
/// * Customize the metadata service credentials in some way
/// * Bypass the [Application Default Credentials] lookup and only
///   use the metadata server credentials
/// * Use the credentials directly outside the client libraries
///
/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
#[derive(Default, Debug)]
pub struct Builder {
    /// Base URL of the metadata service; `None` selects the default root.
    endpoint: Option<String>,
    /// Quota project to attach to the resulting credentials.
    quota_project_id: Option<String>,
    /// Scopes to request; `None` lets the metadata service pick its defaults.
    scopes: Option<Vec<String>>,
    /// Universe domain override.
    universe_domain: Option<String>,
}
109
110impl Builder {
111    /// Sets the endpoint for this credentials.
112    ///
113    /// A trailing slash is significant, so specify the base URL without a trailing  
114    /// slash. If not set, the credentials use `http://metadata.google.internal`.
115    ///
116    /// # Example
117    /// ```
118    /// # use google_cloud_auth::credentials::mds::Builder;
119    /// # tokio_test::block_on(async {
120    /// let credentials = Builder::default()
121    ///     .with_endpoint("https://metadata.google.foobar")
122    ///     .build();
123    /// # });
124    /// ```
125    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
126        self.endpoint = Some(endpoint.into());
127        self
128    }
129
130    /// Set the [quota project] for this credentials.
131    ///
132    /// In some services, you can use a service account in
133    /// one project for authentication and authorization, and charge
134    /// the usage to a different project. This may require that the
135    /// service account has `serviceusage.services.use` permissions on the quota project.
136    ///
137    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
138    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
139        self.quota_project_id = Some(quota_project_id.into());
140        self
141    }
142
143    /// Sets the universe domain for this credentials.
144    ///
145    /// Client libraries use `universe_domain` to determine
146    /// the API endpoints to use for making requests.
147    /// If not set, then credentials use `${service}.googleapis.com`,
148    /// otherwise they use `${service}.${universe_domain}.
149    pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
150        self.universe_domain = Some(universe_domain.into());
151        self
152    }
153
154    /// Sets the [scopes] for this credentials.
155    ///
156    /// Metadata server issues tokens based on the requested scopes.
157    /// If no scopes are specified, the credentials defaults to all
158    /// scopes configured for the [default service account] on the instance.
159    ///
160    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
161    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
162    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
163    where
164        I: IntoIterator<Item = S>,
165        S: Into<String>,
166    {
167        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
168        self
169    }
170
171    fn build_token_provider(self) -> MDSAccessTokenProvider {
172        let endpoint = match std::env::var(GCE_METADATA_HOST_ENV_VAR) {
173            Ok(endpoint) => format!("http://{}", endpoint),
174            _ => self.endpoint.clone().unwrap_or(METADATA_ROOT.to_string()),
175        };
176        MDSAccessTokenProvider::builder()
177            .endpoint(endpoint)
178            .maybe_scopes(self.scopes)
179            .build()
180    }
181
182    /// Returns a [Credentials] instance with the configured settings.
183    pub fn build(self) -> Result<Credentials> {
184        let mdsc = MDSCredentials {
185            quota_project_id: self.quota_project_id.clone(),
186            universe_domain: self.universe_domain.clone(),
187            token_provider: TokenCache::new(self.build_token_provider()),
188        };
189        Ok(Credentials {
190            inner: Arc::new(mdsc),
191        })
192    }
193}
194
195#[async_trait::async_trait]
196impl<T> CredentialsProvider for MDSCredentials<T>
197where
198    T: CachedTokenProvider,
199{
200    async fn headers(&self, extensions: Extensions) -> Result<HeaderMap> {
201        let token = self.token_provider.token(extensions).await?;
202        build_bearer_headers(&token, &self.quota_project_id)
203    }
204
205    async fn universe_domain(&self) -> Option<String> {
206        if self.universe_domain.is_some() {
207            return self.universe_domain.clone();
208        }
209        return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
210    }
211}
212
213#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
214struct ServiceAccountInfo {
215    email: String,
216    scopes: Option<Vec<String>>,
217    aliases: Option<Vec<String>>,
218}
219
220#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
221struct MDSTokenResponse {
222    access_token: String,
223    #[serde(skip_serializing_if = "Option::is_none")]
224    expires_in: Option<u64>,
225    token_type: String,
226}
227
228#[derive(Debug, Clone, Default, Builder)]
229struct MDSAccessTokenProvider {
230    #[builder(into)]
231    scopes: Option<Vec<String>>,
232    #[builder(into)]
233    endpoint: String,
234}
235
236#[async_trait]
237impl TokenProvider for MDSAccessTokenProvider {
238    async fn token(&self) -> Result<Token> {
239        let client = Client::new();
240        let request = client
241            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
242            .header(
243                METADATA_FLAVOR,
244                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
245            );
246        // Use the `scopes` option if set, otherwise let the MDS use the default
247        // scopes.
248        let scopes = self.scopes.as_ref().map(|v| v.join(","));
249        let request = scopes
250            .into_iter()
251            .fold(request, |r, s| r.query(&[("scopes", s)]));
252
253        let response = request.send().await.map_err(errors::retryable)?;
254        // Process the response
255        if !response.status().is_success() {
256            let status = response.status();
257            let body = response
258                .text()
259                .await
260                .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
261            return Err(CredentialsError::from_str(
262                is_retryable(status),
263                format!("Failed to fetch token. {body}"),
264            ));
265        }
266        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
267            let retryable = !e.is_decode();
268            CredentialsError::new(retryable, e)
269        })?;
270        let token = Token {
271            token: response.access_token,
272            token_type: response.token_type,
273            expires_at: response
274                .expires_in
275                .map(|d| Instant::now() + Duration::from_secs(d)),
276            metadata: None,
277        };
278        Ok(token)
279    }
280}
281
#[cfg(test)]
mod test {
    use super::*;
    use crate::credentials::QUOTA_PROJECT_KEY;
    use crate::credentials::test::{get_token_from_headers, get_token_type_from_headers};
    use crate::token::test::MockTokenProvider;
    use axum::extract::Query;
    use axum::response::IntoResponse;
    use http::header::AUTHORIZATION;
    use reqwest::StatusCode;
    use reqwest::header::HeaderMap;
    use scoped_env::ScopedEnv;
    use serde::Deserialize;
    use serde_json::Value;
    use serial_test::{parallel, serial};
    use std::collections::HashMap;
    use std::error::Error;
    use std::sync::Mutex;
    use tokio::task::JoinHandle;
    use url::Url;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    // Query parameters the fake metadata server expects on each request.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct TokenQueryParams {
        scopes: Option<String>,
        recursive: Option<String>,
    }

    #[test]
    fn validate_default_endpoint_urls() {
        let default_endpoint_address = Url::parse(&format!("{}{}", METADATA_ROOT, MDS_DEFAULT_URI));
        assert!(default_endpoint_address.is_ok());

        let token_endpoint_address =
            Url::parse(&format!("{}{}/token", METADATA_ROOT, MDS_DEFAULT_URI));
        assert!(token_endpoint_address.is_ok());
    }

    // Tests that do not touch the process environment are marked `#[parallel]`
    // so `serial_test` keeps them from running concurrently with the
    // `#[serial]` test that mutates `GCE_METADATA_HOST`.
    #[tokio::test]
    #[parallel]
    async fn headers_success() {
        let token = Token {
            token: "test-token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            metadata: None,
        };

        let mut mock = MockTokenProvider::new();
        mock.expect_token().times(1).return_once(|| Ok(token));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };
        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let token = headers.get(AUTHORIZATION).unwrap();

        assert_eq!(headers.len(), 1, "{headers:?}");
        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
        assert!(token.is_sensitive());
    }

    #[tokio::test]
    #[parallel]
    async fn headers_failure() {
        let mut mock = MockTokenProvider::new();
        mock.expect_token()
            .times(1)
            .return_once(|| Err(errors::non_retryable_from_str("fail")));

        let mdsc = MDSCredentials {
            quota_project_id: None,
            universe_domain: None,
            token_provider: TokenCache::new(mock),
        };
        assert!(mdsc.headers(Extensions::new()).await.is_err());
    }

    fn handle_token_factory(
        response_code: StatusCode,
        response_headers: HeaderMap,
        response_body: Value,
    ) -> impl IntoResponse {
        (response_code, response_headers, response_body.to_string()).into_response()
    }

    // Maps a path to (status, body, expected query, call counter).
    type Handlers = HashMap<String, (StatusCode, Value, TokenQueryParams, Arc<Mutex<i32>>)>;

    // Starts a server running locally that responds on multiple paths.
    // Returns an (endpoint, server) pair.
    async fn start(path_handlers: Handlers) -> (String, JoinHandle<()>) {
        let mut app = axum::Router::new();

        for (path, (code, body, expected_query, call_count)) in path_handlers {
            let header_map = HeaderMap::new();
            let handler = move |Query(query): Query<TokenQueryParams>| {
                let body = body.clone();
                let header_map = header_map.clone();
                async move {
                    assert_eq!(expected_query, query);
                    let mut count = call_count.lock().unwrap();
                    *count += 1;
                    handle_token_factory(code, header_map, body)
                }
            };
            app = app.route(&path, axum::routing::get(handler));
        }

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (format!("http://{}:{}", addr.ip(), addr.port()), server)
    }

    #[tokio::test]
    #[serial]
    async fn test_gce_metadata_host_env_var() {
        let scopes = ["scope1".to_string(), "scope2".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        // `GCE_METADATA_HOST` is read only by `Builder::build`, so limit the
        // override to that call. The guard restores the previous value when it
        // goes out of scope at the end of this block.
        let mdsc = {
            // The variable holds a host, so trim 'http://' from the fake
            // server's endpoint; the builder adds the scheme back.
            let _env = ScopedEnv::set(
                super::GCE_METADATA_HOST_ENV_VAR,
                endpoint.strip_prefix("http://").unwrap_or(&endpoint),
            );
            Builder::default()
                .with_scopes(["scope1", "scope2"])
                .build()
                .unwrap()
        };
        let headers = mdsc.headers(Extensions::new()).await.unwrap();

        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
    }

    #[tokio::test]
    #[parallel]
    async fn headers_success_with_quota_project() -> TestResult {
        let scopes = ["scope1".to_string(), "scope2".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_scopes(["scope1", "scope2"])
            .with_endpoint(endpoint)
            .with_quota_project_id("test-project")
            .build()?;

        let headers = mdsc.headers(Extensions::new()).await.unwrap();
        let token = headers.get(AUTHORIZATION).unwrap();
        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();

        assert_eq!(headers.len(), 2, "{headers:?}");
        assert_eq!(
            token,
            HeaderValue::from_static("test-token-type test-access-token")
        );
        assert!(token.is_sensitive());
        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
        assert!(!quota_project.is_sensitive());
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn token_caching() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let call_count = Arc::new(Mutex::new(0));
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                call_count.clone(),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_scopes(scopes)
            .with_endpoint(endpoint)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );

        // validate that the inner token provider is called only once
        assert_eq!(*call_count.lock().unwrap(), 1);

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let token = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build_token_provider()
            .token()
            .await?;

        let now = tokio::time::Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d >= now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    #[parallel]
    async fn token_provider_full_no_scopes() -> TestResult {
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: Some(3600),
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();

        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: None,
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;
        let token = Builder::default()
            .with_endpoint(endpoint)
            .build_token_provider()
            .token()
            .await?;

        let now = Instant::now();
        assert_eq!(token.token, "test-access-token");
        assert_eq!(token.token_type, "test-token-type");
        assert!(
            token
                .expires_at
                .is_some_and(|d| d == now + Duration::from_secs(3600))
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credential_provider_full() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let response = MDSTokenResponse {
            access_token: "test-access-token".to_string(),
            expires_in: None,
            token_type: "test-token-type".to_string(),
        };
        let response_body = serde_json::to_value(&response).unwrap();
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                response_body,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;
        let headers = mdsc.headers(Extensions::new()).await?;
        assert_eq!(
            get_token_from_headers(&headers).unwrap(),
            "test-access-token"
        );
        assert_eq!(
            get_token_type_from_headers(&headers).unwrap(),
            "test-token-type"
        );

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_retryable_error() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::SERVICE_UNAVAILABLE,
                serde_json::to_value("try again")?,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;
        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(e.is_retryable());
        assert!(e.source().unwrap().to_string().contains("try again"));

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_nonretryable_error() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::UNAUTHORIZED,
                serde_json::to_value("epic fail".to_string())?,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_retryable());
        assert!(e.source().unwrap().to_string().contains("epic fail"));

        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[parallel]
    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
        let scopes = vec!["scope1".to_string()];
        let (endpoint, _server) = start(Handlers::from([(
            format!("{}/token", MDS_DEFAULT_URI),
            (
                StatusCode::OK,
                serde_json::to_value("bad json".to_string())?,
                TokenQueryParams {
                    scopes: Some(scopes.join(",")),
                    recursive: None,
                },
                Arc::new(Mutex::new(0)),
            ),
        )]))
        .await;

        let mdsc = Builder::default()
            .with_endpoint(endpoint)
            .with_scopes(scopes)
            .build()?;

        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
        assert!(!e.is_retryable());

        Ok(())
    }

    // `Builder::build` reads `GCE_METADATA_HOST`, so this must not overlap the
    // `#[serial]` test that sets it.
    #[tokio::test]
    #[parallel]
    async fn get_default_universe_domain_success() -> TestResult {
        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
        Ok(())
    }

    #[tokio::test]
    #[parallel]
    async fn get_custom_universe_domain_success() -> TestResult {
        let universe_domain = "test-universe";
        let universe_domain_response = Builder::default()
            .with_universe_domain(universe_domain)
            .build()?
            .universe_domain()
            .await
            .unwrap();
        assert_eq!(universe_domain_response, universe_domain);

        Ok(())
    }
}