google_cloud_auth/credentials/
mds.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Metadata Service] Credentials type.
16//!
17//! Google Cloud environments such as [Google Compute Engine (GCE)][gce-link],
18//! [Google Kubernetes Engine (GKE)][gke-link], or [Cloud Run] provide a metadata service.
19//! This is a local service to the VM (or pod) which (as the name implies) provides
20//! metadata information about the VM. The service also provides access
21//! tokens associated with the [default service account] for the corresponding
22//! VM.
23//!
24//! The default host name of the metadata service is `metadata.google.internal`.
25//! If you would like to use a different hostname, you can set it using the
26//! `GCE_METADATA_HOST` environment variable.
27//!
28//! You can use this access token to securely authenticate with Google Cloud,
29//! without having to download secrets or other credentials. The types in this
30//! module allow you to retrieve these access tokens, and can be used with
31//! the Google Cloud client libraries for Rust.
32//!
//! While the Google Cloud client libraries for Rust default to
//! using the types defined in this module, you may want to use said types directly
//! to customize some of the properties of these credentials.
36//!
37//! # Example
38//! ```
39//! # use google_cloud_auth::credentials::mds::Builder;
40//! # use google_cloud_auth::credentials::Credentials;
41//! # use http::Extensions;
42//! # tokio_test::block_on(async {
43//! let credentials: Credentials = Builder::default()
44//!     .with_quota_project_id("my-quota-project")
45//!     .build()?;
46//! let headers = credentials.headers(Extensions::new()).await?;
47//! println!("Headers: {headers:?}");
48//! # Ok::<(), anyhow::Error>(())
49//! # });
50//! ```
51//!
52//! [Cloud Run]: https://cloud.google.com/run
53//! [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
54//! [gce-link]: https://cloud.google.com/products/compute
55//! [gke-link]: https://cloud.google.com/kubernetes-engine
56//! [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
57
58use crate::credentials::dynamic::CredentialsProvider;
59use crate::credentials::{CacheableResource, Credentials, DEFAULT_UNIVERSE_DOMAIN};
60use crate::errors::CredentialsError;
61use crate::headers_util::build_cacheable_headers;
62use crate::token::{CachedTokenProvider, Token, TokenProvider};
63use crate::token_cache::TokenCache;
64use crate::{BuildResult, Result};
65use async_trait::async_trait;
66use bon::Builder;
67use http::{Extensions, HeaderMap, HeaderValue};
68use reqwest::Client;
69use std::default::Default;
70use std::sync::Arc;
71use std::time::Duration;
72use tokio::time::Instant;
73
// Value sent in the `metadata-flavor` request header on token requests.
const METADATA_FLAVOR_VALUE: &str = "Google";
// Name of the request header that carries `METADATA_FLAVOR_VALUE`.
const METADATA_FLAVOR: &str = "metadata-flavor";
// Default metadata service endpoint, used when neither the `GCE_METADATA_HOST`
// environment variable nor the builder supplies one.
const METADATA_ROOT: &str = "http://metadata.google.internal";
// Path prefix, relative to the endpoint, for the default service account.
// The token request appends `/token` to it.
const MDS_DEFAULT_URI: &str = "/computeMetadata/v1/instance/service-accounts/default";
// Environment variable naming an alternative metadata host. The value is used
// without a scheme: `http://` is prepended when the token provider is built.
const GCE_METADATA_HOST_ENV_VAR: &str = "GCE_METADATA_HOST";
// Error message used when the credentials were created by ADC and the endpoint
// was not overridden.
// TODO(#2235) - Improve this message by talking about retries when really running with MDS
const MDS_NOT_FOUND_ERROR: &str = concat!(
    "Could not fetch an auth token to authenticate with Google Cloud. ",
    "The most common reason for this problem is that you are not running in a Google Cloud Environment ",
    "and you have not configured local credentials for development and testing. ",
    "To setup local credentials, run `gcloud auth application-default login`. ",
    "More information on how to authenticate client libraries can be found at https://cloud.google.com/docs/authentication/client-libraries"
);
87
/// Credentials backed by the metadata service.
///
/// Pairs a cached token provider with the optional settings used when
/// building headers and when reporting the universe domain.
#[derive(Debug)]
struct MDSCredentials<T>
where
    T: CachedTokenProvider,
{
    /// Passed to `build_cacheable_headers` together with the token.
    quota_project_id: Option<String>,
    /// Explicit universe domain; when `None` the default universe domain is reported.
    universe_domain: Option<String>,
    /// Source of the tokens used to build the headers.
    token_provider: T,
}
97
/// Creates [Credentials] instances backed by the [Metadata Service].
///
/// While the Google Cloud client libraries for Rust default to credentials
/// backed by the metadata service, some applications may need to:
/// * Customize the metadata service credentials in some way
/// * Bypass the [Application Default Credentials] lookup and only
///   use the metadata server credentials
/// * Use the credentials directly outside the client libraries
///
/// [Application Default Credentials]: https://cloud.google.com/docs/authentication/application-default-credentials
/// [Metadata Service]: https://cloud.google.com/compute/docs/metadata/overview
#[derive(Debug, Default)]
pub struct Builder {
    /// Endpoint set via `with_endpoint`; `None` means no builder override.
    endpoint: Option<String>,
    /// Quota project set via `with_quota_project_id`.
    quota_project_id: Option<String>,
    /// Scopes set via `with_scopes`; `None` leaves the choice to the metadata service.
    scopes: Option<Vec<String>>,
    /// Universe domain set via `with_universe_domain`.
    universe_domain: Option<String>,
    /// `true` only when the builder was created through `from_adc()`.
    created_by_adc: bool,
}
117
118impl Builder {
119    /// Sets the endpoint for this credentials.
120    ///
121    /// A trailing slash is significant, so specify the base URL without a trailing  
122    /// slash. If not set, the credentials use `http://metadata.google.internal`.
123    ///
124    /// # Example
125    /// ```
126    /// # use google_cloud_auth::credentials::mds::Builder;
127    /// # tokio_test::block_on(async {
128    /// let credentials = Builder::default()
129    ///     .with_endpoint("https://metadata.google.foobar")
130    ///     .build();
131    /// # });
132    /// ```
133    pub fn with_endpoint<S: Into<String>>(mut self, endpoint: S) -> Self {
134        self.endpoint = Some(endpoint.into());
135        self
136    }
137
138    /// Set the [quota project] for this credentials.
139    ///
140    /// In some services, you can use a service account in
141    /// one project for authentication and authorization, and charge
142    /// the usage to a different project. This may require that the
143    /// service account has `serviceusage.services.use` permissions on the quota project.
144    ///
145    /// [quota project]: https://cloud.google.com/docs/quotas/quota-project
146    pub fn with_quota_project_id<S: Into<String>>(mut self, quota_project_id: S) -> Self {
147        self.quota_project_id = Some(quota_project_id.into());
148        self
149    }
150
151    /// Sets the universe domain for this credentials.
152    ///
153    /// Client libraries use `universe_domain` to determine
154    /// the API endpoints to use for making requests.
155    /// If not set, then credentials use `${service}.googleapis.com`,
156    /// otherwise they use `${service}.${universe_domain}.
157    pub fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
158        self.universe_domain = Some(universe_domain.into());
159        self
160    }
161
162    /// Sets the [scopes] for this credentials.
163    ///
164    /// Metadata server issues tokens based on the requested scopes.
165    /// If no scopes are specified, the credentials defaults to all
166    /// scopes configured for the [default service account] on the instance.
167    ///
168    /// [default service account]: https://cloud.google.com/iam/docs/service-account-types#default
169    /// [scopes]: https://developers.google.com/identity/protocols/oauth2/scopes
170    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
171    where
172        I: IntoIterator<Item = S>,
173        S: Into<String>,
174    {
175        self.scopes = Some(scopes.into_iter().map(|s| s.into()).collect());
176        self
177    }
178
179    // This method is used to build mds credentials from ADC
180    pub(crate) fn from_adc() -> Self {
181        Self {
182            created_by_adc: true,
183            ..Default::default()
184        }
185    }
186
187    fn build_token_provider(self) -> MDSAccessTokenProvider {
188        let final_endpoint: String;
189        let endpoint_overridden: bool;
190
191        // Determine the endpoint and whether it was overridden
192        if let Ok(host_from_env) = std::env::var(GCE_METADATA_HOST_ENV_VAR) {
193            // Check GCE_METADATA_HOST environment variable first
194            final_endpoint = format!("http://{}", host_from_env);
195            endpoint_overridden = true;
196        } else if let Some(builder_endpoint) = self.endpoint {
197            // Else, check if an endpoint was provided to the mds::Builder
198            final_endpoint = builder_endpoint;
199            endpoint_overridden = true;
200        } else {
201            // Else, use the default metadata root
202            final_endpoint = METADATA_ROOT.to_string();
203            endpoint_overridden = false;
204        };
205
206        MDSAccessTokenProvider::builder()
207            .endpoint(final_endpoint)
208            .maybe_scopes(self.scopes)
209            .endpoint_overridden(endpoint_overridden)
210            .created_by_adc(self.created_by_adc)
211            .build()
212    }
213
214    /// Returns a [Credentials] instance with the configured settings.
215    pub fn build(self) -> BuildResult<Credentials> {
216        let mdsc = MDSCredentials {
217            quota_project_id: self.quota_project_id.clone(),
218            universe_domain: self.universe_domain.clone(),
219            token_provider: TokenCache::new(self.build_token_provider()),
220        };
221        Ok(Credentials {
222            inner: Arc::new(mdsc),
223        })
224    }
225}
226
227#[async_trait::async_trait]
228impl<T> CredentialsProvider for MDSCredentials<T>
229where
230    T: CachedTokenProvider,
231{
232    async fn headers(&self, extensions: Extensions) -> Result<CacheableResource<HeaderMap>> {
233        let cached_token = self.token_provider.token(extensions).await?;
234        build_cacheable_headers(&cached_token, &self.quota_project_id)
235    }
236
237    async fn universe_domain(&self) -> Option<String> {
238        if self.universe_domain.is_some() {
239            return self.universe_domain.clone();
240        }
241        return Some(DEFAULT_UNIVERSE_DOMAIN.to_string());
242    }
243}
244
// NOTE(review): not referenced in the visible portion of this file. Presumably
// models the service account description served by the metadata service;
// confirm against its users.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct ServiceAccountInfo {
    email: String,
    scopes: Option<Vec<String>>,
    aliases: Option<Vec<String>>,
}
251
/// JSON body decoded from the response to the token request.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct MDSTokenResponse {
    /// Becomes `Token::token`.
    access_token: String,
    /// Token lifetime in seconds, counted from the moment the response is
    /// processed. Optional; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    /// Becomes `Token::token_type`.
    token_type: String,
}
259
// Token provider that requests access tokens from a metadata service endpoint.
#[derive(Debug, Clone, Default, Builder)]
struct MDSAccessTokenProvider {
    // Scopes sent, comma-joined, as the `scopes` query parameter; when `None`
    // no `scopes` parameter is sent.
    #[builder(into)]
    scopes: Option<Vec<String>>,
    // Base URL of the metadata service; the token path is appended to it.
    #[builder(into)]
    endpoint: String,
    // Whether `endpoint` came from the environment variable or the builder
    // rather than from the default metadata root.
    endpoint_overridden: bool,
    // Whether the credentials were created through ADC.
    created_by_adc: bool,
}
269
270impl MDSAccessTokenProvider {
271    // During ADC, if no credentials are found in the well-known location and the GOOGLE_APPLICATION_CREDENTIALS
272    // environment variable is not set, we default to MDS credentials without checking if the code is really
273    // running in an environment with MDS. To help users who got to this state because of lack of credentials
274    // setup on their machines, we provide a detailed error message to them talking about local setup and other
275    // auth mechanisms available to them.
276    // If the endpoint is overridden, even if ADC was used to create the MDS credentials, we do not give a detailed
277    // error message because they deliberately wanted to use an MDS.
278    fn error_message(&self) -> &str {
279        if self.use_adc_message() {
280            MDS_NOT_FOUND_ERROR
281        } else {
282            "failed to fetch token"
283        }
284    }
285
286    fn use_adc_message(&self) -> bool {
287        self.created_by_adc && !self.endpoint_overridden
288    }
289}
290
291#[async_trait]
292impl TokenProvider for MDSAccessTokenProvider {
293    async fn token(&self) -> Result<Token> {
294        let client = Client::new();
295        let request = client
296            .get(format!("{}{}/token", self.endpoint, MDS_DEFAULT_URI))
297            .header(
298                METADATA_FLAVOR,
299                HeaderValue::from_static(METADATA_FLAVOR_VALUE),
300            );
301        // Use the `scopes` option if set, otherwise let the MDS use the default
302        // scopes.
303        let scopes = self.scopes.as_ref().map(|v| v.join(","));
304        let request = scopes
305            .into_iter()
306            .fold(request, |r, s| r.query(&[("scopes", s)]));
307
308        // If the connection to MDS was not successful, it is useful to retry when really
309        // running on MDS environments and not useful if there is no MDS. We will mark the error
310        // as retryable and let the retry policy determine whether to retry or not. Whenever we
311        // define a default retry policy, we can skip retrying this case.
312        let response = request
313            .send()
314            .await
315            .map_err(|e| crate::errors::from_http_error(e, self.error_message()))?;
316        // Process the response
317        if !response.status().is_success() {
318            let err = crate::errors::from_http_response(response, self.error_message()).await;
319            return Err(err);
320        }
321        let response = response.json::<MDSTokenResponse>().await.map_err(|e| {
322            // Decoding errors are not transient. Typically they indicate a badly
323            // configured MDS endpoint, or DNS redirecting the request to a random
324            // server, e.g., ISPs that redirect unknown services to HTTP.
325            CredentialsError::from_source(!e.is_decode(), e)
326        })?;
327        let token = Token {
328            token: response.access_token,
329            token_type: response.token_type,
330            expires_at: response
331                .expires_in
332                .map(|d| Instant::now() + Duration::from_secs(d)),
333            metadata: None,
334        };
335        Ok(token)
336    }
337}
338
339#[cfg(test)]
340mod test {
341    use super::*;
342    use crate::credentials::QUOTA_PROJECT_KEY;
343    use crate::credentials::test::{
344        get_headers_from_cache, get_token_from_headers, get_token_type_from_headers,
345    };
346    use crate::errors;
347    use crate::token::test::MockTokenProvider;
348    use axum::extract::Query;
349    use axum::response::IntoResponse;
350    use http::header::AUTHORIZATION;
351    use reqwest::StatusCode;
352    use reqwest::header::HeaderMap;
353    use scoped_env::ScopedEnv;
354    use serde::Deserialize;
355    use serde_json::Value;
356    use serial_test::{parallel, serial};
357    use std::collections::HashMap;
358    use std::error::Error;
359    use std::sync::Mutex;
360    use test_case::test_case;
361    use tokio::task::JoinHandle;
362    use url::Url;
363
    // Return type for tests that propagate failures with `?`.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;
365
    // Define a struct to capture query parameters
    // The fake server's handler extracts each request's query string into this
    // type and compares it with the expected value.
    #[derive(Debug, Clone, Deserialize, PartialEq)]
    struct TokenQueryParams {
        // Value of the `scopes` query parameter, if present.
        scopes: Option<String>,
        // Value of the `recursive` query parameter, if present.
        recursive: Option<String>,
    }
372
373    #[test]
374    fn validate_default_endpoint_urls() {
375        let default_endpoint_address = Url::parse(&format!("{}{}", METADATA_ROOT, MDS_DEFAULT_URI));
376        assert!(default_endpoint_address.is_ok());
377
378        let token_endpoint_address =
379            Url::parse(&format!("{}{}/token", METADATA_ROOT, MDS_DEFAULT_URI));
380        assert!(token_endpoint_address.is_ok());
381    }
382
383    #[tokio::test]
384    async fn headers_success() -> TestResult {
385        let token = Token {
386            token: "test-token".to_string(),
387            token_type: "Bearer".to_string(),
388            expires_at: None,
389            metadata: None,
390        };
391
392        let mut mock = MockTokenProvider::new();
393        mock.expect_token().times(1).return_once(|| Ok(token));
394
395        let mdsc = MDSCredentials {
396            quota_project_id: None,
397            universe_domain: None,
398            token_provider: TokenCache::new(mock),
399        };
400
401        let mut extensions = Extensions::new();
402        let cached_headers = mdsc.headers(extensions.clone()).await.unwrap();
403        let (headers, entity_tag) = match cached_headers {
404            CacheableResource::New { entity_tag, data } => (data, entity_tag),
405            CacheableResource::NotModified => unreachable!("expecting new headers"),
406        };
407        let token = headers.get(AUTHORIZATION).unwrap();
408        assert_eq!(headers.len(), 1, "{headers:?}");
409        assert_eq!(token, HeaderValue::from_static("Bearer test-token"));
410        assert!(token.is_sensitive());
411
412        extensions.insert(entity_tag);
413
414        let cached_headers = mdsc.headers(extensions).await?;
415
416        match cached_headers {
417            CacheableResource::New { .. } => unreachable!("expecting new headers"),
418            CacheableResource::NotModified => CacheableResource::<HeaderMap>::NotModified,
419        };
420        Ok(())
421    }
422
423    #[tokio::test]
424    async fn headers_failure() {
425        let mut mock = MockTokenProvider::new();
426        mock.expect_token()
427            .times(1)
428            .return_once(|| Err(errors::non_retryable_from_str("fail")));
429
430        let mdsc = MDSCredentials {
431            quota_project_id: None,
432            universe_domain: None,
433            token_provider: TokenCache::new(mock),
434        };
435        assert!(mdsc.headers(Extensions::new()).await.is_err());
436    }
437
438    #[test]
439    fn error_message_with_adc() {
440        let provider = MDSAccessTokenProvider::builder()
441            .endpoint("http://127.0.0.1")
442            .created_by_adc(true)
443            .endpoint_overridden(false)
444            .build();
445
446        let want = MDS_NOT_FOUND_ERROR;
447        let got = provider.error_message();
448        assert!(got.contains(want), "{got}, {provider:?}");
449    }
450
451    #[test_case(false, false)]
452    #[test_case(false, true)]
453    #[test_case(true, true)]
454    fn error_message_without_adc(adc: bool, overridden: bool) {
455        let provider = MDSAccessTokenProvider::builder()
456            .endpoint("http://127.0.0.1")
457            .created_by_adc(adc)
458            .endpoint_overridden(overridden)
459            .build();
460
461        let not_want = MDS_NOT_FOUND_ERROR;
462        let got = provider.error_message();
463        assert!(!got.contains(not_want), "{got}, {provider:?}");
464    }
465
466    #[tokio::test]
467    #[serial]
468    async fn adc_no_mds() -> TestResult {
469        let err = Builder::from_adc()
470            .build_token_provider()
471            .token()
472            .await
473            .unwrap_err();
474
475        assert!(err.is_transient(), "{err:?}");
476        assert!(
477            err.to_string().contains("application-default"),
478            "display={err}, debug={err:?}"
479        );
480        let source = err
481            .source()
482            .and_then(|e| e.downcast_ref::<reqwest::Error>());
483        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
484
485        Ok(())
486    }
487
488    #[tokio::test]
489    #[serial]
490    async fn adc_overridden_mds() -> TestResult {
491        let _e = ScopedEnv::set(super::GCE_METADATA_HOST_ENV_VAR, "metadata.overridden");
492
493        let err = Builder::from_adc()
494            .build_token_provider()
495            .token()
496            .await
497            .unwrap_err();
498
499        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
500
501        assert!(err.is_transient(), "{err:?}");
502        assert!(
503            !err.to_string().contains("application-default"),
504            "display={err}, debug={err:?}"
505        );
506        let source = err
507            .source()
508            .and_then(|e| e.downcast_ref::<reqwest::Error>());
509        assert!(matches!(source, Some(e) if e.status().is_none()), "{err:?}");
510
511        Ok(())
512    }
513
514    #[tokio::test]
515    #[serial]
516    async fn builder_no_mds() -> TestResult {
517        let e = Builder::default()
518            .build_token_provider()
519            .token()
520            .await
521            .err()
522            .unwrap();
523
524        assert!(e.is_transient(), "{e:?}");
525        assert!(
526            !format!("{:?}", e.source()).contains("application-default"),
527            "{e:?}"
528        );
529
530        Ok(())
531    }
532
533    fn handle_token_factory(
534        response_code: StatusCode,
535        response_headers: HeaderMap,
536        response_body: Value,
537    ) -> impl IntoResponse {
538        (response_code, response_headers, response_body.to_string()).into_response()
539    }
540
    // Maps a request path to (response status, response body, expected query
    // parameters, call counter) for the fake server created by `start`.
    type Handlers = HashMap<String, (StatusCode, Value, TokenQueryParams, Arc<Mutex<i32>>)>;
542
543    // Starts a server running locally that responds on multiple paths.
544    // Returns an (endpoint, server) pair.
545    async fn start(path_handlers: Handlers) -> (String, JoinHandle<()>) {
546        let mut app = axum::Router::new();
547
548        for (path, (code, body, expected_query, call_count)) in path_handlers {
549            let header_map = HeaderMap::new();
550            let handler = move |Query(query): Query<TokenQueryParams>| {
551                let body = body.clone();
552                let header_map = header_map.clone();
553                async move {
554                    assert_eq!(expected_query, query);
555                    let mut count = call_count.lock().unwrap();
556                    *count += 1;
557                    handle_token_factory(code, header_map, body)
558                }
559            };
560            app = app.route(&path, axum::routing::get(handler));
561        }
562
563        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
564        let addr = listener.local_addr().unwrap();
565        let server = tokio::spawn(async move {
566            axum::serve(listener, app).await.unwrap();
567        });
568        (format!("http://{}:{}", addr.ip(), addr.port()), server)
569    }
570
571    #[tokio::test]
572    #[serial]
573    async fn test_gce_metadata_host_env_var() {
574        let scopes = ["scope1".to_string(), "scope2".to_string()];
575        let response = MDSTokenResponse {
576            access_token: "test-access-token".to_string(),
577            expires_in: Some(3600),
578            token_type: "test-token-type".to_string(),
579        };
580        let response_body = serde_json::to_value(&response).unwrap();
581
582        let (endpoint, _server) = start(Handlers::from([(
583            format!("{}/token", MDS_DEFAULT_URI),
584            (
585                StatusCode::OK,
586                response_body,
587                TokenQueryParams {
588                    scopes: Some(scopes.join(",")),
589                    recursive: None,
590                },
591                Arc::new(Mutex::new(0)),
592            ),
593        )]))
594        .await;
595
596        // Trim out 'http://' from the endpoint provided by the fake server
597        let _e = ScopedEnv::set(
598            super::GCE_METADATA_HOST_ENV_VAR,
599            endpoint.strip_prefix("http://").unwrap_or(&endpoint),
600        );
601        let mdsc = Builder::default()
602            .with_scopes(["scope1", "scope2"])
603            .build()
604            .unwrap();
605        let headers = mdsc.headers(Extensions::new()).await.unwrap();
606        let _e = ScopedEnv::remove(super::GCE_METADATA_HOST_ENV_VAR);
607
608        assert_eq!(
609            get_token_from_headers(headers).unwrap(),
610            "test-access-token"
611        );
612    }
613
614    #[tokio::test]
615    #[parallel]
616    async fn headers_success_with_quota_project() -> TestResult {
617        let scopes = ["scope1".to_string(), "scope2".to_string()];
618        let response = MDSTokenResponse {
619            access_token: "test-access-token".to_string(),
620            expires_in: Some(3600),
621            token_type: "test-token-type".to_string(),
622        };
623        let response_body = serde_json::to_value(&response).unwrap();
624
625        let (endpoint, _server) = start(Handlers::from([(
626            format!("{}/token", MDS_DEFAULT_URI),
627            (
628                StatusCode::OK,
629                response_body,
630                TokenQueryParams {
631                    scopes: Some(scopes.join(",")),
632                    recursive: None,
633                },
634                Arc::new(Mutex::new(0)),
635            ),
636        )]))
637        .await;
638
639        let mdsc = Builder::default()
640            .with_scopes(["scope1", "scope2"])
641            .with_endpoint(endpoint)
642            .with_quota_project_id("test-project")
643            .build()?;
644
645        let headers = get_headers_from_cache(mdsc.headers(Extensions::new()).await.unwrap())?;
646        let token = headers.get(AUTHORIZATION).unwrap();
647        let quota_project = headers.get(QUOTA_PROJECT_KEY).unwrap();
648
649        assert_eq!(headers.len(), 2, "{headers:?}");
650        assert_eq!(
651            token,
652            HeaderValue::from_static("test-token-type test-access-token")
653        );
654        assert!(token.is_sensitive());
655        assert_eq!(quota_project, HeaderValue::from_static("test-project"));
656        assert!(!quota_project.is_sensitive());
657
658        Ok(())
659    }
660
661    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
662    #[parallel]
663    async fn token_caching() -> TestResult {
664        let scopes = vec!["scope1".to_string()];
665        let response = MDSTokenResponse {
666            access_token: "test-access-token".to_string(),
667            expires_in: Some(3600),
668            token_type: "test-token-type".to_string(),
669        };
670        let response_body = serde_json::to_value(&response).unwrap();
671
672        let call_count = Arc::new(Mutex::new(0));
673        let (endpoint, _server) = start(Handlers::from([(
674            format!("{}/token", MDS_DEFAULT_URI),
675            (
676                StatusCode::OK,
677                response_body,
678                TokenQueryParams {
679                    scopes: Some(scopes.join(",")),
680                    recursive: None,
681                },
682                call_count.clone(),
683            ),
684        )]))
685        .await;
686
687        let mdsc = Builder::default()
688            .with_scopes(scopes)
689            .with_endpoint(endpoint)
690            .build()?;
691        let headers = mdsc.headers(Extensions::new()).await?;
692        assert_eq!(
693            get_token_from_headers(headers).unwrap(),
694            "test-access-token"
695        );
696        let headers = mdsc.headers(Extensions::new()).await?;
697        assert_eq!(
698            get_token_from_headers(headers).unwrap(),
699            "test-access-token"
700        );
701
702        // validate that the inner token provider is called only once
703        assert_eq!(*call_count.lock().unwrap(), 1);
704
705        Ok(())
706    }
707
708    #[tokio::test(start_paused = true)]
709    #[parallel]
710    async fn token_provider_full() -> TestResult {
711        let scopes = vec!["scope1".to_string()];
712        let response = MDSTokenResponse {
713            access_token: "test-access-token".to_string(),
714            expires_in: Some(3600),
715            token_type: "test-token-type".to_string(),
716        };
717        let response_body = serde_json::to_value(&response).unwrap();
718
719        let (endpoint, _server) = start(Handlers::from([(
720            format!("{}/token", MDS_DEFAULT_URI),
721            (
722                StatusCode::OK,
723                response_body,
724                TokenQueryParams {
725                    scopes: Some(scopes.join(",")),
726                    recursive: None,
727                },
728                Arc::new(Mutex::new(0)),
729            ),
730        )]))
731        .await;
732        println!("endpoint = {endpoint}");
733
734        let token = Builder::default()
735            .with_endpoint(endpoint)
736            .with_scopes(scopes)
737            .build_token_provider()
738            .token()
739            .await?;
740
741        let now = tokio::time::Instant::now();
742        assert_eq!(token.token, "test-access-token");
743        assert_eq!(token.token_type, "test-token-type");
744        assert!(
745            token
746                .expires_at
747                .is_some_and(|d| d >= now + Duration::from_secs(3600))
748        );
749
750        Ok(())
751    }
752
753    #[tokio::test(start_paused = true)]
754    #[parallel]
755    async fn token_provider_full_no_scopes() -> TestResult {
756        let response = MDSTokenResponse {
757            access_token: "test-access-token".to_string(),
758            expires_in: Some(3600),
759            token_type: "test-token-type".to_string(),
760        };
761        let response_body = serde_json::to_value(&response).unwrap();
762
763        let (endpoint, _server) = start(Handlers::from([(
764            format!("{}/token", MDS_DEFAULT_URI),
765            (
766                StatusCode::OK,
767                response_body,
768                TokenQueryParams {
769                    scopes: None,
770                    recursive: None,
771                },
772                Arc::new(Mutex::new(0)),
773            ),
774        )]))
775        .await;
776        println!("endpoint = {endpoint}");
777        let token = Builder::default()
778            .with_endpoint(endpoint)
779            .build_token_provider()
780            .token()
781            .await?;
782
783        let now = Instant::now();
784        assert_eq!(token.token, "test-access-token");
785        assert_eq!(token.token_type, "test-token-type");
786        assert!(
787            token
788                .expires_at
789                .is_some_and(|d| d == now + Duration::from_secs(3600))
790        );
791
792        Ok(())
793    }
794
795    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
796    #[parallel]
797    async fn credential_provider_full() -> TestResult {
798        let scopes = vec!["scope1".to_string()];
799        let response = MDSTokenResponse {
800            access_token: "test-access-token".to_string(),
801            expires_in: None,
802            token_type: "test-token-type".to_string(),
803        };
804        let response_body = serde_json::to_value(&response).unwrap();
805        let (endpoint, _server) = start(Handlers::from([(
806            format!("{}/token", MDS_DEFAULT_URI),
807            (
808                StatusCode::OK,
809                response_body,
810                TokenQueryParams {
811                    scopes: Some(scopes.join(",")),
812                    recursive: None,
813                },
814                Arc::new(Mutex::new(0)),
815            ),
816        )]))
817        .await;
818        println!("endpoint = {endpoint}");
819
820        let mdsc = Builder::default()
821            .with_endpoint(endpoint)
822            .with_scopes(scopes)
823            .build()?;
824        let headers = mdsc.headers(Extensions::new()).await?;
825        assert_eq!(
826            get_token_from_headers(headers.clone()).unwrap(),
827            "test-access-token"
828        );
829        assert_eq!(
830            get_token_type_from_headers(headers).unwrap(),
831            "test-token-type"
832        );
833
834        Ok(())
835    }
836
837    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
838    #[parallel]
839    async fn credentials_headers_retryable_error() -> TestResult {
840        let scopes = vec!["scope1".to_string()];
841        let (endpoint, _server) = start(Handlers::from([(
842            format!("{}/token", MDS_DEFAULT_URI),
843            (
844                StatusCode::SERVICE_UNAVAILABLE,
845                serde_json::to_value("try again")?,
846                TokenQueryParams {
847                    scopes: Some(scopes.join(",")),
848                    recursive: None,
849                },
850                Arc::new(Mutex::new(0)),
851            ),
852        )]))
853        .await;
854
855        let mdsc = Builder::default()
856            .with_endpoint(endpoint)
857            .with_scopes(scopes)
858            .build()?;
859        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
860        assert!(err.is_transient());
861        assert!(err.to_string().contains("try again"), "{err:?}");
862        let source = err
863            .source()
864            .and_then(|e| e.downcast_ref::<reqwest::Error>());
865        assert!(
866            matches!(source, Some(e) if e.status() == Some(StatusCode::SERVICE_UNAVAILABLE)),
867            "{err:?}"
868        );
869
870        Ok(())
871    }
872
873    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
874    #[parallel]
875    async fn credentials_headers_nonretryable_error() -> TestResult {
876        let scopes = vec!["scope1".to_string()];
877        let (endpoint, _server) = start(Handlers::from([(
878            format!("{}/token", MDS_DEFAULT_URI),
879            (
880                StatusCode::UNAUTHORIZED,
881                serde_json::to_value("epic fail".to_string())?,
882                TokenQueryParams {
883                    scopes: Some(scopes.join(",")),
884                    recursive: None,
885                },
886                Arc::new(Mutex::new(0)),
887            ),
888        )]))
889        .await;
890
891        let mdsc = Builder::default()
892            .with_endpoint(endpoint)
893            .with_scopes(scopes)
894            .build()?;
895
896        let err = mdsc.headers(Extensions::new()).await.unwrap_err();
897        assert!(!err.is_transient());
898        assert!(err.to_string().contains("epic fail"), "{err:?}");
899        let source = err
900            .source()
901            .and_then(|e| e.downcast_ref::<reqwest::Error>());
902        assert!(
903            matches!(source, Some(e) if e.status() == Some(StatusCode::UNAUTHORIZED)),
904            "{err:?}"
905        );
906
907        Ok(())
908    }
909
910    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
911    #[parallel]
912    async fn credentials_headers_malformed_response_is_nonretryable() -> TestResult {
913        let scopes = vec!["scope1".to_string()];
914        let (endpoint, _server) = start(Handlers::from([(
915            format!("{}/token", MDS_DEFAULT_URI),
916            (
917                StatusCode::OK,
918                serde_json::to_value("bad json".to_string())?,
919                TokenQueryParams {
920                    scopes: Some(scopes.join(",")),
921                    recursive: None,
922                },
923                Arc::new(Mutex::new(0)),
924            ),
925        )]))
926        .await;
927
928        let mdsc = Builder::default()
929            .with_endpoint(endpoint)
930            .with_scopes(scopes)
931            .build()?;
932
933        let e = mdsc.headers(Extensions::new()).await.err().unwrap();
934        assert!(!e.is_transient());
935
936        Ok(())
937    }
938
939    #[tokio::test]
940    async fn get_default_universe_domain_success() -> TestResult {
941        let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
942        assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
943        Ok(())
944    }
945
946    #[tokio::test]
947    async fn get_custom_universe_domain_success() -> TestResult {
948        let universe_domain = "test-universe";
949        let universe_domain_response = Builder::default()
950            .with_universe_domain(universe_domain)
951            .build()?
952            .universe_domain()
953            .await
954            .unwrap();
955        assert_eq!(universe_domain_response, universe_domain);
956
957        Ok(())
958    }
959}