google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! You can use this access token to securely authenticate with Google Cloud,
25//! without having to download secrets or other credentials. The types in this
26//! module allow you to retrieve these access tokens, and can be used with
27//! the Google Cloud client libraries for Rust.
28//!
//! While the Google Cloud client libraries for Rust default to
//! using the types defined in this module, you may want to use said types directly
//! to customize some of the properties of these credentials.
32//!
33//! Example usage:
34//!
35//! ```
36//! # use google_cloud_auth::credentials::mds::Builder;
37//! # use google_cloud_auth::credentials::Credentials;
38//! # use google_cloud_auth::errors::CredentialsError;
39//! # tokio_test::block_on(async {
40//! let credentials: Credentials = Builder::default()
41//!     .with_quota_project_id("my-quota-project")
42//!     .build();
43//! let token = credentials.token().await?;
44//! println!("Token: {}", token.token);
45//! # Ok::<(), CredentialsError>(())
46//! # });
47//! ```
48//!
49//! [Cloud Run]: https://cloud.google.com/run
50//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
51//! [gce-link]: https://cloud.google.com/products/compute
52//! [gke-link]: https://cloud.google.com/kubernetes-engine
53//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
54
55use crate::credentials::dynamic::CredentialsTrait;
56use crate::credentials::{Credentials, DEFAULT_UNIVERSE_DOMAIN, Result};
57use crate::errors::{self, CredentialsError, is_retryable};
58use crate::headers_util::build_bearer_headers;
59use crate::token::{Token, TokenProvider};
60use async_trait::async_trait;
61use bon::Builder;
62use http::header::{HeaderName, HeaderValue};
63use reqwest::Client;
64use std::default::Default;
65use std::sync::Arc;
66use std::time::Duration;
67
/// Value sent in the `metadata-flavor` header of every metadata service request.
const METADATA_FLAVOR_VALUE: &str = "Google";
/// Name of the header attached to every metadata service request.
const METADATA_FLAVOR: &str = "metadata-flavor";
/// Endpoint used when the builder is not given an explicit one.
const METADATA_ROOT: &str = "http://metadata.google.internal/";
71
72pub(crate) fn new() -> Credentials {
73    Builder::default().build()
74}
75
76#[derive(Debug)]
77struct MDSCredentials<T>
78where
79    T: TokenProvider,
80{
81    quota_project_id: Option<String>,
82    universe_domain: Option<String>,
83    token_provider: T,
84}
85
/// Creates [Credentials] instances backed by the [Metadata Service].
///
/// While the Google Cloud client libraries for Rust default to credentials
/// backed by the metadata service, some applications may need to:
/// * Customize the metadata service credentials in some way
/// * Bypass the [Application Default Credentials] lookup and only
///   use the metadata server credentials
/// * Use the credentials directly outside the client libraries
///
/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
#[derive(Debug, Default)]
pub struct Builder {
    // Metadata service endpoint; `None` selects the default root at build time.
    endpoint: Option<String>,
    // Quota project forwarded to the credentials unchanged.
    quota_project_id: Option<String>,
    // Scopes to request; `None` leaves the choice to the token provider.
    scopes: Option<Vec<String>>,
    // Universe domain override forwarded to the credentials unchanged.
    universe_domain: Option<String>,
}
104
105impl Builder {
106    /// Sets the endpoint for this credentials.
107    ///
108    /// If not set, the credentials use `http://metadata.google.internal/`.
109    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
110        self.endpoint = Some(endpoint.into());
111        self
112    }
113
114    /// Set the [quota project] for this credentials.
115    ///
116    /// In some services, you can use a service account in
117    /// one project for authentication and authorization, and charge
118    /// the usage to a different project. This may require that the
119    /// service account has `serviceusage.services.use` permissions on the quota project.
120    ///
121    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
122    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
123        self.quota_project_id = Some(quota_project_id.into());
124        self
125    }
126
127    /// Sets the universe domain for this credentials.
128    ///
129    /// Client libraries use `universe_domain` to determine
130    /// the API endpoints to use for making requests.
131    /// If not set, then credentials use `${service}.googleapis.com`,
132    /// otherwise they use `${service}.${universe_domain}.
133    pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
134        self.universe_domain = Some(universe_domain.into());
135        self
136    }
137
138    /// Sets the [scopes] for this credentials.
139    ///
140    /// Metadata server issues tokens based on the requested scopes.
141    /// If no scopes are specified, the credentials defaults to all
142    /// scopes configured for the [default service account] on the instance.
143    ///
144    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
145    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
146    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
147    where
148        I: IntoIterator<Item = S>,
149        S: Into<String>,
150    {
151        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
152        self
153    }
154
155    /// Returns a [Credentials] instance with the configured settings.
156    pub fn build(self) -> Credentials {
157        let endpoint = self.endpoint.clone().unwrap_or(METADATA_ROOT.to_string());
158
159        let token_provider = MDSAccessTokenProvider::builder()
160            .endpoint(endpoint)
161            .maybe_scopes(self.scopes)
162            .build();
163        let cached_token_provider = crate::token_cache::TokenCache::new(token_provider);
164
165        let mdsc = MDSCredentials {
166            quota_project_id: self.quota_project_id,
167            token_provider: cached_token_provider,
168            universe_domain: self.universe_domain,
169        };
170        Credentials {
171            inner: Arc::new(mdsc),
172        }
173    }
174}
175
176#[async_trait::async_trait]
177impl<T> CredentialsTrait for MDSCredentials<T>
178where
179    T: TokenProvider,
180{
181    async fn token(&self) -> Result<Token> {
182        self.token_provider.token().await
183    }
184
185    async fn headers(&self) -> Result<Vec<(HeaderName, HeaderValue)>> {
186        let token = self.token().await?;
187        build_bearer_headers(&token, &self.quota_project_id)
188    }
189
190    async fn universe_domain(&self) -> Option<String> {
191        if self.universe_domain.is_some() {
192            return self.universe_domain.clone();
193        }
194        return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
195    }
196}
197
/// JSON payload returned by the `service-accounts/default/` metadata path
/// when it is queried with `recursive=true`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct ServiceAccountInfo {
    email: String,
    // Used as the scopes to request when the caller configured none.
    scopes: Option<Vec<String>>,
    aliases: Option<Vec<String>>,
}
204
/// JSON payload returned by the `service-accounts/default/token` metadata path.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    access_token: String,
    // Token lifetime in seconds; omitted from the serialized form when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    token_type: String,
}
212
/// Token provider that requests access tokens from a metadata service endpoint.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    // Scopes to request; when `None` they are looked up from the service
    // account info before each token request.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    // Base URL that the `/computeMetadata/...` request paths are appended to.
    #[builder(into)]
    endpoint: String,
}
220
221impl MDSAccessTokenProvider {
222    async fn get_service_account_info(&self, client: &Client) -> Result<ServiceAccountInfo> {
223        let request = client
224            .get(format!(
225                "{}/computeMetadata/v1/instance/service-accounts/default/",
226                self.endpoint
227            ))
228            .query(&[("recursive", "true")])
229            .header(
230                METADATA_FLAVOR,
231                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
232            );
233
234        let response = request.send().await.map_err(errors::retryable)?;
235
236        response
237            .json::<ServiceAccountInfo>()
238            .await
239            .map_err(errors::non_retryable)
240    }
241}
242
243#[async_trait]
244impl TokenProvider for MDSAccessTokenProvider {
245    async fn token(&self) -> Result<Token> {
246        let client = Client::new();
247        // Determine scopes, fetching from metadata server if needed.
248        let scopes = match &self.scopes {
249            Some(s) => s.clone().join(","),
250            None => {
251                let service_account_info = self.get_service_account_info(&client).await?;
252                service_account_info.scopes.unwrap_or_default().join(",")
253            }
254        };
255
256        let request = client
257            .get(format!(
258                "{}/computeMetadata/v1/instance/service-accounts/default/token",
259                self.endpoint
260            ))
261            .query(&[("scopes", scopes)])
262            .header(
263                METADATA_FLAVOR,
264                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
265            );
266
267        let response = request.send().await.map_err(errors::retryable)?;
268        // Process the response
269        if !response.status().is_success() {
270            let status = response.status();
271            let body = response
272                .text()
273                .await
274                .map_err(|e| CredentialsError::new(is_retryable(status), e))?;
275            return Err(CredentialsError::from_str(
276                is_retryable(status),
277                format!("Failed to fetch token. {body}"),
278            ));
279        }
280        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
281            let retryable = !e.is_decode();
282            CredentialsError::new(retryable, e)
283        })?;
284        let token = Token {
285            token: response.access_token,
286            token_type: response.token_type,
287            expires_at: response
288                .expires_in
289                .map(|d| std::time::Instant::now() + Duration::from_secs(d)),
290            metadata: None,
291        };
292        Ok(token)
293    }
294}
295
296#[cfg(test)]
297mod test {
298    use super::*;
299    use crate::credentials::QUOTA_PROJECT_KEY;
300    use crate::credentials::test::HV;
301    use crate::token::test::MockTokenProvider;
302    use axum::extract::Query;
303    use axum::response::IntoResponse;
304    use http::header::AUTHORIZATION;
305    use reqwest::StatusCode;
306    use reqwest::header::HeaderMap;
307    use serde::Deserialize;
308    use serde_json::Value;
309    use std::collections::HashMap;
310    use std::error::Error;
311    use std::sync::Mutex;
312    use tokio::task::JoinHandle;
313
    // Return type for tests that propagate failures with `?`.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;
    // Path of the token route registered on the local test server.
    const MDS_TOKEN_URI: &str = "/computeMetadata/v1/instance/service-accounts/default/token";
316
    // Query parameters extracted from each request received by the test
    // server; the handler compares them against the expected values.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct TokenQueryParams {
        scopes: Option<String>,
        recursive: Option<String>,
    }
323
324    #[tokio::test]
325    async fn token_success() {
326        let expected = Token {
327            token: "test-token".to_string(),
328            token_type: "Bearer".to_string(),
329            expires_at: None,
330            metadata: None,
331        };
332        let expected_clone = expected.clone();
333
334        let mut mock = MockTokenProvider::new();
335        mock.expect_token()
336            .times(1)
337            .return_once(|| Ok(expected_clone));
338
339        let mdsc = MDSCredentials {
340            quota_project_id: None,
341            universe_domain: None,
342            token_provider: mock,
343        };
344        let actual = mdsc.token().await.unwrap();
345        assert_eq!(actual, expected);
346    }
347
348    #[tokio::test]
349    async fn token_failure() {
350        let mut mock = MockTokenProvider::new();
351        mock.expect_token()
352            .times(1)
353            .return_once(|| Err(errors::non_retryable_from_str("fail")));
354
355        let mdsc = MDSCredentials {
356            quota_project_id: None,
357            universe_domain: None,
358            token_provider: mock,
359        };
360        assert!(mdsc.token().await.is_err());
361    }
362
363    #[tokio::test]
364    async fn headers_success() {
365        let token = Token {
366            token: "test-token".to_string(),
367            token_type: "Bearer".to_string(),
368            expires_at: None,
369            metadata: None,
370        };
371
372        let mut mock = MockTokenProvider::new();
373        mock.expect_token().times(1).return_once(|| Ok(token));
374
375        let mdsc = MDSCredentials {
376            quota_project_id: None,
377            universe_domain: None,
378            token_provider: mock,
379        };
380        let headers: Vec<HV> = HV::from(mdsc.headers().await.unwrap());
381
382        assert_eq!(
383            headers,
384            vec![HV {
385                header: AUTHORIZATION.to_string(),
386                value: "Bearer test-token".to_string(),
387                is_sensitive: true,
388            }]
389        );
390    }
391
392    #[tokio::test]
393    async fn headers_failure() {
394        let mut mock = MockTokenProvider::new();
395        mock.expect_token()
396            .times(1)
397            .return_once(|| Err(errors::non_retryable_from_str("fail")));
398
399        let mdsc = MDSCredentials {
400            quota_project_id: None,
401            universe_domain: None,
402            token_provider: mock,
403        };
404        assert!(mdsc.headers().await.is_err());
405    }
406
407    fn handle_token_factory(
408        response_code: StatusCode,
409        response_headers: HeaderMap,
410        response_body: Value,
411    ) -> impl IntoResponse {
412        (response_code, response_headers, response_body.to_string()).into_response()
413    }
414
    // Maps a route path to (response status, response body, expected query
    // parameters, counter incremented on every call to that route).
    type Handlers = HashMap<String, (StatusCode, Value, TokenQueryParams, Arc<Mutex<i32>>)>;
416
417    // Starts a server running locally that responds on multiple paths.
418    // Returns an (endpoint, server) pair.
419    async fn start(path_handlers: Handlers) -> (String, JoinHandle<()>) {
420        let mut app = axum::Router::new();
421
422        for (path, (code, body, expected_query, call_count)) in path_handlers {
423            let header_map = HeaderMap::new();
424            let handler = move |Query(query): Query<TokenQueryParams>| {
425                let body = body.clone();
426                let header_map = header_map.clone();
427                async move {
428                    assert_eq!(expected_query, query);
429                    let mut count = call_count.lock().unwrap();
430                    *count += 1;
431                    handle_token_factory(code, header_map, body)
432                }
433            };
434            app = app.route(&path, axum::routing::get(handler));
435        }
436
437        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
438        let addr = listener.local_addr().unwrap();
439        let server = tokio::spawn(async move {
440            axum::serve(listener, app).await.unwrap();
441        });
442        (format!("http://{}:{}", addr.ip(), addr.port()), server)
443    }
444
445    #[tokio::test]
446    async fn get_default_service_account_info_success() {
447        let service_account = "default";
448        let path = format!(
449            "/computeMetadata/v1/instance/service-accounts/{}/",
450            service_account
451        );
452        let service_account_info = ServiceAccountInfo {
453            email: "test@test.com".to_string(),
454            scopes: Some(vec!["scope 1".to_string(), "scope 2".to_string()]),
455            aliases: None,
456        };
457        let service_account_info_json = serde_json::to_value(service_account_info.clone()).unwrap();
458        let (endpoint, _server) = start(Handlers::from([(
459            path,
460            (
461                StatusCode::OK,
462                service_account_info_json,
463                TokenQueryParams {
464                    scopes: None,
465                    recursive: Some("true".to_string()),
466                },
467                Arc::new(Mutex::new(0)),
468            ),
469        )]))
470        .await;
471
472        let request = Client::new();
473        let token_provider = MDSAccessTokenProvider::builder().endpoint(endpoint).build();
474
475        let result = token_provider.get_service_account_info(&request).await;
476
477        assert!(result.is_ok());
478        assert_eq!(result.unwrap(), service_account_info);
479    }
480
481    #[tokio::test]
482    async fn get_service_account_info_server_error() {
483        let path = "/computeMetadata/v1/instance/service-accounts/default/".to_string();
484        let (endpoint, _server) = start(Handlers::from([(
485            path,
486            (
487                StatusCode::SERVICE_UNAVAILABLE,
488                serde_json::to_value("try again").unwrap(),
489                TokenQueryParams {
490                    scopes: None,
491                    recursive: Some("true".to_string()),
492                },
493                Arc::new(Mutex::new(0)),
494            ),
495        )]))
496        .await;
497
498        let request = Client::new();
499        let token_provider = MDSAccessTokenProvider::builder().endpoint(endpoint).build();
500
501        let result = token_provider.get_service_account_info(&request).await;
502        assert!(result.is_err());
503    }
504
505    #[tokio::test]
506    async fn headers_success_with_quota_project() {
507        let scopes = ["scope1".to_string(), "scope2".to_string()];
508        let response = MDSTokenResponse {
509            access_token: "test-access-token".to_string(),
510            expires_in: Some(3600),
511            token_type: "test-token-type".to_string(),
512        };
513        let response_body = serde_json::to_value(&response).unwrap();
514
515        let (endpoint, _server) = start(Handlers::from([(
516            MDS_TOKEN_URI.to_string(),
517            (
518                StatusCode::OK,
519                response_body,
520                TokenQueryParams {
521                    scopes: Some(scopes.join(",")),
522                    recursive: None,
523                },
524                Arc::new(Mutex::new(0)),
525            ),
526        )]))
527        .await;
528
529        let mdsc = Builder::default()
530            .with_scopes(["scope1", "scope2"])
531            .with_endpoint(endpoint)
532            .with_quota_project_id("test-project")
533            .build();
534
535        let headers: Vec<HV> = HV::from(mdsc.headers().await.unwrap());
536        assert_eq!(
537            headers,
538            vec![
539                HV {
540                    header: AUTHORIZATION.to_string(),
541                    value: "test-token-type test-access-token".to_string(),
542                    is_sensitive: true,
543                },
544                HV {
545                    header: QUOTA_PROJECT_KEY.to_string(),
546                    value: "test-project".to_string(),
547                    is_sensitive: false,
548                }
549            ]
550        );
551    }
552
553    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
554    async fn token_caching() -> TestResult {
555        let scopes = vec!["scope1".to_string()];
556        let response = MDSTokenResponse {
557            access_token: "test-access-token".to_string(),
558            expires_in: Some(3600),
559            token_type: "test-token-type".to_string(),
560        };
561        let response_body = serde_json::to_value(&response).unwrap();
562
563        let call_count = Arc::new(Mutex::new(0));
564        let (endpoint, _server) = start(Handlers::from([(
565            MDS_TOKEN_URI.to_string(),
566            (
567                StatusCode::OK,
568                response_body,
569                TokenQueryParams {
570                    scopes: Some(scopes.join(",")),
571                    recursive: None,
572                },
573                call_count.clone(),
574            ),
575        )]))
576        .await;
577
578        let mdsc = Builder::default()
579            .with_scopes(scopes)
580            .with_endpoint(endpoint)
581            .build();
582        let token = mdsc.token().await?;
583        assert_eq!(token.token, "test-access-token");
584        let token = mdsc.token().await?;
585        assert_eq!(token.token, "test-access-token");
586
587        // validate that the inner token provider is called only once
588        assert_eq!(*call_count.lock().unwrap(), 1);
589
590        Ok(())
591    }
592
593    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
594    async fn token_provider_full() -> TestResult {
595        let scopes = vec!["scope1".to_string()];
596        let response = MDSTokenResponse {
597            access_token: "test-access-token".to_string(),
598            expires_in: Some(3600),
599            token_type: "test-token-type".to_string(),
600        };
601        let response_body = serde_json::to_value(&response).unwrap();
602
603        let (endpoint, _server) = start(Handlers::from([(
604            MDS_TOKEN_URI.to_string(),
605            (
606                StatusCode::OK,
607                response_body,
608                TokenQueryParams {
609                    scopes: Some(scopes.join(",")),
610                    recursive: None,
611                },
612                Arc::new(Mutex::new(0)),
613            ),
614        )]))
615        .await;
616        println!("endpoint = {endpoint}");
617
618        let mdsc = Builder::default()
619            .with_scopes(scopes)
620            .with_endpoint(endpoint)
621            .build();
622        let now = std::time::Instant::now();
623        let token = mdsc.token().await?;
624        assert_eq!(token.token, "test-access-token");
625        assert_eq!(token.token_type, "test-token-type");
626        assert!(
627            token
628                .expires_at
629                .is_some_and(|d| d >= now + Duration::from_secs(3600))
630        );
631
632        Ok(())
633    }
634
635    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
636    async fn token_provider_full_no_scopes() -> TestResult {
637        let service_account_info_path =
638            "/computeMetadata/v1/instance/service-accounts/default/".to_string();
639        let scopes = vec!["scope 1".to_string(), "scope 2".to_string()];
640        let service_account_info = ServiceAccountInfo {
641            email: "test@test.com".to_string(),
642            scopes: Some(scopes.clone()),
643            aliases: None,
644        };
645        let service_account_info_json = serde_json::to_value(service_account_info.clone()).unwrap();
646
647        let response = MDSTokenResponse {
648            access_token: "test-access-token".to_string(),
649            expires_in: Some(3600),
650            token_type: "test-token-type".to_string(),
651        };
652        let response_body = serde_json::to_value(&response).unwrap();
653
654        let (endpoint, _server) = start(Handlers::from([
655            (
656                service_account_info_path,
657                (
658                    StatusCode::OK,
659                    service_account_info_json,
660                    TokenQueryParams {
661                        scopes: None,
662                        recursive: Some("true".to_string()),
663                    },
664                    Arc::new(Mutex::new(0)),
665                ),
666            ),
667            (
668                MDS_TOKEN_URI.to_string(),
669                (
670                    StatusCode::OK,
671                    response_body,
672                    TokenQueryParams {
673                        scopes: Some(scopes.join(",")),
674                        recursive: None,
675                    },
676                    Arc::new(Mutex::new(0)),
677                ),
678            ),
679        ]))
680        .await;
681        println!("endpoint = {endpoint}");
682
683        let mdsc = Builder::default().with_endpoint(endpoint).build();
684        let now = std::time::Instant::now();
685        let token = mdsc.token().await?;
686        assert_eq!(token.token, "test-access-token");
687        assert_eq!(token.token_type, "test-token-type");
688        assert!(
689            token
690                .expires_at
691                .is_some_and(|d| d >= now + Duration::from_secs(3600))
692        );
693
694        Ok(())
695    }
696
697    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
698    async fn token_provider_partial() -> TestResult {
699        let scopes = vec!["scope1".to_string()];
700        let response = MDSTokenResponse {
701            access_token: "test-access-token".to_string(),
702            expires_in: None,
703            token_type: "test-token-type".to_string(),
704        };
705        let response_body = serde_json::to_value(&response).unwrap();
706        let (endpoint, _server) = start(Handlers::from([(
707            MDS_TOKEN_URI.to_string(),
708            (
709                StatusCode::OK,
710                response_body,
711                TokenQueryParams {
712                    scopes: Some(scopes.join(",")),
713                    recursive: None,
714                },
715                Arc::new(Mutex::new(0)),
716            ),
717        )]))
718        .await;
719        println!("endpoint = {endpoint}");
720
721        let mdsc = Builder::default()
722            .with_endpoint(endpoint)
723            .with_scopes(scopes)
724            .build();
725        let token = mdsc.token().await?;
726        assert_eq!(token.token, "test-access-token");
727        assert_eq!(token.token_type, "test-token-type");
728        assert_eq!(token.expires_at, None);
729
730        Ok(())
731    }
732
733    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
734    async fn token_provider_retryable_error() -> TestResult {
735        let scopes = vec!["scope1".to_string()];
736        let (endpoint, _server) = start(Handlers::from([(
737            MDS_TOKEN_URI.to_string(),
738            (
739                StatusCode::SERVICE_UNAVAILABLE,
740                serde_json::to_value("try again")?,
741                TokenQueryParams {
742                    scopes: Some(scopes.join(",")),
743                    recursive: None,
744                },
745                Arc::new(Mutex::new(0)),
746            ),
747        )]))
748        .await;
749
750        let mdsc = Builder::default()
751            .with_endpoint(endpoint)
752            .with_scopes(scopes)
753            .build();
754        let e = mdsc.token().await.err().unwrap();
755        assert!(e.is_retryable());
756        assert!(e.source().unwrap().to_string().contains("try again"));
757
758        Ok(())
759    }
760
761    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
762    async fn token_provider_nonretryable_error() -> TestResult {
763        let scopes = vec!["scope1".to_string()];
764        let (endpoint, _server) = start(Handlers::from([(
765            MDS_TOKEN_URI.to_string(),
766            (
767                StatusCode::UNAUTHORIZED,
768                serde_json::to_value("epic fail".to_string())?,
769                TokenQueryParams {
770                    scopes: Some(scopes.join(",")),
771                    recursive: None,
772                },
773                Arc::new(Mutex::new(0)),
774            ),
775        )]))
776        .await;
777
778        let mdsc = Builder::default()
779            .with_endpoint(endpoint)
780            .with_scopes(scopes)
781            .build();
782
783        let e = mdsc.token().await.err().unwrap();
784        assert!(!e.is_retryable());
785        assert!(e.source().unwrap().to_string().contains("epic fail"));
786
787        Ok(())
788    }
789
790    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
791    async fn token_provider_malformed_response_is_nonretryable() -> TestResult {
792        let scopes = vec!["scope1".to_string()];
793        let (endpoint, _server) = start(Handlers::from([(
794            MDS_TOKEN_URI.to_string(),
795            (
796                StatusCode::OK,
797                serde_json::to_value("bad json".to_string())?,
798                TokenQueryParams {
799                    scopes: Some(scopes.join(",")),
800                    recursive: None,
801                },
802                Arc::new(Mutex::new(0)),
803            ),
804        )]))
805        .await;
806
807        let mdsc = Builder::default()
808            .with_endpoint(endpoint)
809            .with_scopes(scopes)
810            .build();
811
812        let e = mdsc.token().await.err().unwrap();
813        assert!(!e.is_retryable());
814
815        Ok(())
816    }
817
818    #[tokio::test]
819    async fn get_default_universe_domain_success() {
820        let universe_domain_response = Builder::default().build().universe_domain().await.unwrap();
821        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
822    }
823
824    #[tokio::test]
825    async fn get_custom_universe_domain_success() {
826        let universe_domain = "test-universe";
827        let universe_domain_response = Builder::default()
828            .with_universe_domain(universe_domain)
829            .build()
830            .universe_domain()
831            .await
832            .unwrap();
833        assert_eq!(universe_domain_response, universe_domain);
834    }
835}